|
4 | 4 |
|
5 | 5 | 为了方便 PyTorch 自定义算子快速接入 PaddlePaddle 框架,我们提供了如下图所示的兼容机制: |
6 | 6 |
|
7 | | - |
| 7 | +<figure align="center"> |
| 8 | + <img src="https://github.com/PaddlePaddle/docs/blob/develop/docs/guides/custom_op/cross_ecosystem_custom_op/images/cross-ecosystem-custom-op-compatible.drawio.png?raw=true" width="700" alt='missing' align="center"/> |
| 9 | + <figcaption><center>跨生态自定义算子兼容机制示意图</center></figcaption> |
| 10 | +</figure> |
8 | 11 |
|
9 | 12 | 正如图上所示,我们自底向上提供了如下几层支持: |
10 | 13 |
|
@@ -262,23 +265,33 @@ using at::Tensor; |
262 | 265 |
|
263 | 266 | 不过目前兼容层还在持续完善中,部分常见 API 尚未覆盖到,此时就会出现编译错误,你可以根据编译错误提示来定位并修复相关代码。 |
264 | 267 |
|
265 | | -以 `Tensor.reshape` 为例,假设用户在自定义算子中使用了该 API,但 Paddle 没有提供该 API 的兼容实现,就会出现编译错误,此时我们可以选择临时取出 `at::Tensor` 内部的 `paddle::Tensor`,并使用 PaddlePaddle 提供的等效 API 来实现该功能: |
| 268 | +以 `torch::empty` 为例,假设算子库中使用了该 API,但 Paddle 没有提供该 API 的兼容实现,就会出现编译错误: |
| 269 | +
|
| 270 | +```text |
| 271 | +/workspace/cross-ecosystem-custom-op-example/csrc/muladd.cc: In function ‘at::Tensor muladd_cpu(at::Tensor, const at::Tensor&, double)’: |
| 272 | +/workspace/cross-ecosystem-custom-op-example/csrc/muladd.cc:54:30: error: ‘empty’ is not a member of ‘torch’ |
| 273 | + 54 | at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); |
| 274 | + | ^~~~~ |
| 275 | +``` |
| 276 | + |
| 277 | +此时我们可以选择将 PyTorch 的 structs 转换为 Paddle 的 structs,并用 PaddlePaddle 提供的等效 API 来实现该功能: |
| 278 | + |
| 279 | +即将下面的代码: |
266 | 280 |
|
267 | 281 | ```cpp |
268 | 282 | // PyTorch 原代码 |
269 | | -at::IntArrayRef sizes = {2, 3, 4}; |
270 | | -at::Tensor reshaped_tensor = x.reshape(sizes); |
| 283 | +at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); |
271 | 284 | ``` |
272 | 285 |
|
273 | 286 | 我们可以将其替换为: |
274 | 287 |
|
275 | 288 | ```cpp |
276 | 289 | // 替换为 PaddlePaddle 等效实现 |
277 | | -at::IntArrayRef sizes = {2, 3, 4}; |
278 | | -auto paddle_tensor = x._PD_GetInner(); // 获取内部 paddle::Tensor |
279 | | -auto paddle_sizes = sizes._PD_ToPaddleIntArray(); // 转换为 paddle::IntArray |
280 | | -auto paddle_reshaped_tensor = paddle::experimental::reshape(paddle_tensor, paddle_sizes); // 使用 PaddlePaddle reshape API |
281 | | -at::Tensor reshaped_tensor(paddle_reshaped_tensor); // 包装回 at::Tensor |
| 290 | +auto paddle_size = a_contig.sizes()._PD_ToPaddleIntArray(); // 将 PyTorch IntArrayRef 转为 Paddle IntArray |
| 291 | +auto paddle_dtype = compat::_PD_AtenScalarTypeToPhiDataType(a_contig.dtype()); // 将 PyTorch ScalarType 转为 Paddle DataType |
| 292 | +auto paddle_place = a_contig.options()._PD_GetPlace(); // 将 PyTorch Device 转为 Paddle Place |
| 293 | +auto paddle_result = paddle::experimental::empty(paddle_size, paddle_dtype, paddle_place); // 调用 PaddlePaddle 的 empty API |
| 294 | +at::Tensor result(paddle_result); // 将 Paddle Tensor 包装为 PyTorch Tensor |
282 | 295 | ``` |
283 | 296 |
|
284 | 297 | 更多 PaddlePaddle C++ API 的使用方式可参考 [PaddlePaddle C++ 自定义算子文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html)。通过这种方式,你可以逐步修复编译错误,直至自定义算子能够成功编译通过。 |
|
0 commit comments