Skip to content

Commit 85a0c13

Browse files
committed
update design part
1 parent f810ee5 commit 85a0c13

File tree

1 file changed

+295
-1
lines changed

1 file changed

+295
-1
lines changed

docs/guides/custom_op/cross_ecosystem_custom_op/design_and_migration_cn.md

Lines changed: 295 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,298 @@
1919

2020
## 迁移步骤
2121

22-
下面我们以一个简单的 PyTorch 自定义算子为例,介绍如何将其迁移至 PaddlePaddle 框架。
22+
下面我们以一个简单的 PyTorch 自定义算子为例,介绍如何将其迁移至 PaddlePaddle 框架。相关代码见 [PFCCLab/cross-ecosystem-custom-op-example](https://github.com/PFCCLab/cross-ecosystem-custom-op-example)
23+
24+
### 搭建 PyTorch 运行环境
25+
26+
在迁移之前,需要先确保 PyTorch 算子能够在本地正确编译和运行。这可以参考具体的仓库说明文档来完成。
27+
28+
本示例展示的自定义算子可以通过如下方式来成功编译和运行:
29+
30+
```bash
31+
# 首先安装 PyTorch,具体命令可以参考 https://pytorch.org/get-started/locally/
32+
pip install torch
33+
# 克隆示例代码仓库
34+
git clone https://github.com/PFCCLab/cross-ecosystem-custom-op-example.git
35+
cd cross-ecosystem-custom-op-example
36+
# 编译自定义算子
37+
pip install . --no-build-isolation
38+
# 运行测试脚本
39+
python test.py
40+
```
41+
42+
很好,我们已经确保该算子能够在 PyTorch 框架下正确运行,接下来将会介绍如何将其迁移至 PaddlePaddle 框架。
43+
44+
### 清理 PyTorch 环境
45+
46+
在迁移之前,建议先卸载 PyTorch 相关的包,或者者新建一个干净的虚拟环境来进行迁移工作,以避免潜在的包冲突问题,并安装 PaddlePaddle 框架,具体命令可参考 [PaddlePaddle 安装指南](https://www.paddlepaddle.org.cn/install/quick)
47+
48+
### 理解源码结构
49+
50+
当前示例代码的目录结构很简单,如下所示:
51+
52+
```text
53+
.
54+
├── csrc
55+
│ └── muladd.cc # 自定义算子实现
56+
├── extension # 自定义算子 Python package
57+
│ └── __init__.py # 自定义算子 Python 封装
58+
├── pyproject.toml # Python package 配置文件,主要描述 build backend
59+
├── README.md
60+
├── setup.py # Python package 构建脚本,用于 build backend setuptools 的调用
61+
└── test.py # 测试脚本
62+
```
63+
64+
从构建流程来看,我们主要关注的是:
65+
66+
- `pyproject.toml` 作为 PEP 517/518 标准的配置文件,描述了该 Python package 的构建后端为 `setuptools`,以及相关的元信息。
67+
68+
```toml
69+
[build-system]
70+
requires = [
71+
"setuptools",
72+
"torch",
73+
]
74+
build-backend = "setuptools.build_meta"
75+
```
76+
77+
- `setup.py` 作为 `setuptools` 的构建脚本,主要负责调用 `torch.utils.cpp_extension` 模块来编译 C++ 源码并生成可供 Python 调用的扩展模块。
78+
79+
```python
80+
from setuptools import setup, find_packages
81+
from torch.utils import cpp_extension
82+
83+
setup(
84+
name="extension",
85+
packages=find_packages(include=['extension']),
86+
ext_modules=[
87+
cpp_extension.CUDAExtension(
88+
name="extension_cpp",
89+
sources=["csrc/muladd.cc"],
90+
)
91+
],
92+
cmdclass={'build_ext': cpp_extension.BuildExtension},
93+
)
94+
```
95+
96+
- `csrc/muladd.cc` 作为自定义算子的核心实现文件,包含了算子的具体逻辑和注册代码,我们往往可以分为三部分:
97+
- 框架无关的算子逻辑实现部分,这部分逻辑并不使用 PyTorch 的 API,仅仅使用 C++/CUDA 标准库来实现。
98+
99+
```cpp
100+
template<typename T>
101+
void muladd_cpu_impl(const T* a_ptr, const T* b_ptr, T c, T* result_ptr, int64_t numel) {
102+
for (int64_t i = 0; i < numel; i++) {
103+
result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
104+
}
105+
}
106+
```
107+
- PyTorch C++ API 相关的部分,这部分代码会使用 `at::Tensor` 等 PyTorch C++ API 来进行张量操作和内存管理。
108+
```cpp
109+
at::Tensor muladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
110+
TORCH_CHECK(a.sizes() == b.sizes());
111+
TORCH_CHECK(a.dtype() == at::kFloat);
112+
TORCH_CHECK(b.dtype() == at::kFloat);
113+
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
114+
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
115+
at::Tensor a_contig = a.contiguous();
116+
at::Tensor b_contig = b.contiguous();
117+
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
118+
const float* a_ptr = a_contig.data_ptr<float>();
119+
const float* b_ptr = b_contig.data_ptr<float>();
120+
float* result_ptr = result.data_ptr<float>();
121+
muladd_cpu_impl<float>(a_ptr, b_ptr, static_cast<float>(c), result_ptr, result.numel());
122+
return result;
123+
}
124+
```
125+
- 算子注册部分,这部分代码会使用 PyTorch 提供的注册宏(如 `TORCH_LIBRARY`)来完成算子的注册工作。
126+
127+
```cpp
128+
extern "C" {
129+
/* Creates a dummy empty _C module that can be imported from Python.
130+
The import from Python will load the .so consisting of this file
131+
in this extension, so that the TORCH_LIBRARY static initializers
132+
below are run. */
133+
PyObject* PyInit_extension_cpp(void) {
134+
static struct PyModuleDef module_def = {
135+
PyModuleDef_HEAD_INIT,
136+
"extension_cpp", /* name of module */
137+
NULL, /* module documentation, may be NULL */
138+
-1, /* size of per-interpreter state of the module,
139+
or -1 if the module keeps state in global variables. */
140+
NULL, /* methods */
141+
};
142+
return PyModule_Create(&module_def);
143+
}
144+
}
145+
146+
TORCH_LIBRARY(extension_cpp, m) {
147+
m.def("muladd_cpp(Tensor a, Tensor b, float c) -> Tensor");
148+
}
149+
150+
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
151+
m.impl("muladd_cpp", &muladd_cpu);
152+
}
153+
```
154+
155+
156+
从执行流程来看,在 Python 端调用该自定义算子时,主要经历了如下几个步骤:
157+
158+
- 在 `test.py` 中导入自定义算子 Python package `extension`。
159+
- 在 `extension/__init__.py` 中通过 `torch.ops.extension_cpp.muladd_cpp` 来调用 C++ 扩展模块中的自定义算子,从而调用到上面注册的 `muladd_cpp` 算子。
160+
161+
```python
162+
def muladd(a: torch.Tensor, b: torch.Tensor, c: float) -> torch.Tensor:
163+
return torch.ops.extension_cpp.muladd_cpp(a, b, c)
164+
```
165+
166+
### 调整构建脚本,使用 PaddlePaddle 编译自定义算子
167+
168+
由于原本的构建脚本是基于 PyTorch 的 `torch.utils.cpp_extension` 模块来完成编译的,因此我们需要将其替换为 PaddlePaddle 提供的自定义算子编译方式。
169+
170+
由于我们提供了 `paddle.compat.enable_torch_proxy()` 代理层来兼容 PyTorch 的 C++ API,因此我们可以使用该 API 实现 torch API 的一键兼容调用。
171+
172+
```diff
173+
+import paddle
174+
+paddle.compat.enable_torch_proxy() # Enable torch proxy globally
175+
176+
from setuptools import setup, find_packages
177+
# 如下的 torch extension 已经被 PaddlePaddle 的同等功能替代(即 paddle.utils.cpp_extension)
178+
# 下面的代码完全不需要修改即可运行
179+
from torch.utils import cpp_extension
180+
181+
setup(
182+
name="extension",
183+
packages=find_packages(include=['extension']),
184+
ext_modules=[
185+
cpp_extension.CUDAExtension(
186+
name="extension_cpp",
187+
sources=["csrc/muladd.cc"],
188+
)
189+
],
190+
cmdclass={'build_ext': cpp_extension.BuildExtension},
191+
)
192+
```
193+
194+
对于本示例来说,仅仅需要在 `setup.py` 中添加上述两行代码即可完成迁移工作,其他代码均无需修改。但是自定义算子代码库一般各式各样,可能还需要根据实际情况进行一些调整,关于更多细节请参考 [`paddle.utils.cpp_extension.setup` 文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/utils/cpp_extension/setup_cn.html)
195+
196+
### 尝试编译并修复
197+
198+
完成构建脚本的修改后,即可尝试编译自定义算子:
199+
200+
```bash
201+
pip install . --no-build-isolation
202+
```
203+
204+
由于我们提供了 PyTorch C++ API 兼容层,因此理想情况下大多数用户的自定义算子代码都可以直接通过编译而无需修改。
205+
206+
PyTorch C++ API 兼容层本质上是以 PyTorch C++ API 作为用户调用接口,并在底层映射至 PaddlePaddle C++ API 来实现的。以 `at::Tensor` 为例,你所调用的 `at::Tensor` 实际上是一个代理类,该类内部持有一个 `paddle::Tensor` 对象,并将所有对 `at::Tensor` 的操作映射为对 `paddle::Tensor` 的操作。
207+
208+
```cpp
209+
// paddle/phi/api/include/compat/ATen/core/TensorBody.h
210+
namespace at {
211+
using PaddleTensor = paddle::Tensor;
212+
213+
class Tensor : public TensorBase {
214+
public:
215+
Tensor() = default;
216+
Tensor(const PaddleTensor& tensor) : TensorBase(tensor){}; // NOLINT
217+
Tensor(const Tensor& tensor) = default;
218+
Tensor(Tensor&& tensor) = default;
219+
220+
void* data_ptr() const { return const_cast<void*>(tensor_.data()); }
221+
template <typename T>
222+
T* data_ptr() const {
223+
return const_cast<T*>(tensor_.data<T>());
224+
}
225+
226+
c10::IntArrayRef sizes() const {
227+
return compat::_PD_PhiDDimToIntArrayRef(tensor_.dims());
228+
}
229+
230+
int64_t numel() const { return tensor_.numel(); }
231+
232+
c10::ScalarType dtype() const { // Should we use `TypeMeta` here?
233+
return compat::_PD_PhiDataTypeToAtenScalarType(tensor_.dtype());
234+
}
235+
236+
c10::Device device() const { return c10::Device(tensor_.place()); }
237+
238+
int64_t dim() const { return tensor_.dims().size(); }
239+
int64_t ndimension() const { return dim(); }
240+
241+
Tensor& fill_(const at::Scalar& value) const {
242+
paddle::experimental::fill_(const_cast<PaddleTensor&>(tensor_), value);
243+
return const_cast<at::Tensor&>(*this);
244+
}
245+
246+
Tensor& zero_() const {
247+
paddle::experimental::fill_(const_cast<PaddleTensor&>(tensor_), 0.0);
248+
return const_cast<at::Tensor&>(*this);
249+
}
250+
251+
PaddleTensor _PD_GetInner() const { return tensor_; }
252+
PaddleTensor& _PD_GetInner() { return tensor_; }
253+
};
254+
255+
} // namespace at
256+
namespace torch {
257+
using at::Tensor;
258+
} // namespace torch
259+
```
260+
261+
完整的兼容层代码见 [`paddle/phi/api/include/compat`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/phi/api/include/compat),我们提供了与 PyTorch C++ API 相同的头文件结构和命名空间,只需按原有方式调用即可。
262+
263+
不过目前兼容层还在持续完善中,部分常见 API 尚未覆盖到,此时就会出现编译错误,你可以根据编译错误提示来定位并修复相关代码。
264+
265+
以 `Tensor.reshape` 为例,假设用户在自定义算子中使用了该 API,但 Paddle 没有提供该 API 的兼容实现,就会出现编译错误,此时我们可以选择临时取出 `at::Tensor` 内部的 `paddle::Tensor`,并使用 PaddlePaddle 提供的等效 API 来实现该功能:
266+
267+
```cpp
268+
// PyTorch 原代码
269+
at::IntArrayRef sizes = {2, 3, 4};
270+
at::Tensor reshaped_tensor = x.reshape(sizes);
271+
```
272+
273+
我们可以将其替换为:
274+
275+
```cpp
276+
// 替换为 PaddlePaddle 等效实现
277+
at::IntArrayRef sizes = {2, 3, 4};
278+
auto paddle_tensor = x._PD_GetInner(); // 获取内部 paddle::Tensor
279+
auto paddle_sizes = shape._PD_ToPaddleIntArray(); // 转换为 paddle::IntArray
280+
auto paddle_reshaped_tensor = paddle::experimental::reshape(paddle_tensor, sizes); // 使用 PaddlePaddle reshape API
281+
at::Tensor reshaped_tensor(paddle_reshaped_tensor); // 包装回 at::Tensor
282+
```
283+
284+
更多 PaddlePaddle C++ API 的使用方式可参考 [PaddlePaddle C++ 自定义算子文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html)。通过这种方式,你可以逐步修复编译错误,直至自定义算子能够成功编译通过。
285+
286+
### 运行测试并修复
287+
288+
完成编译后,即可运行测试脚本来验证自定义算子的正确性,由于原本的测试脚本是基于 PyTorch 框架来实现的,因此我们需要改写测试脚本以适配 PaddlePaddle 框架。
289+
290+
```python
291+
import paddle
292+
paddle.compat.enable_torch_proxy(scope={"extension"}) # 仅启用 extension 包的 torch 代理
293+
import extension
294+
295+
x = paddle.tensor([1.0, 2.0, 3.0])
296+
y = paddle.tensor([4.0, 5.0, 6.0])
297+
z = 2.0
298+
result = extension.muladd(x, y, z)
299+
print(result) # Expected output: tensor([ 6., 12., 20.])
300+
```
301+
302+
由于 `extension` 包中仍然使用了 `torch` 模块下的 Python API,因此我们需要启用 `torch` 代理来确保这些 API 能够正确映射至 PaddlePaddle 框架。为了避免对其他代码产生影响,我们可以通过 `scope` 参数来限定代理的作用范围。
303+
304+
当然,与 C++ 端类似,Python 端的兼容层也还在持续完善中,部分常见 API 尚未覆盖到,如果遇到此类错误,你可以尝试参考 [PaddlePaddle Python API 文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html)[PyTorch 最新 release 与 Paddle develop API 映射表](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/model_convert/convert_from_pytorch/pytorch_api_mapping_cn.html)来寻找等效的 PaddlePaddle API 并进行替换,直到运行时不再报错且结果正确为止。
305+
306+
至此,一个 PyTorch 自定义算子就成功迁移至 PaddlePaddle 框架了!
307+
308+
### 总结
309+
310+
通过上述步骤,我们介绍了如何将一个简单的 PyTorch 自定义算子迁移至 PaddlePaddle 框架。总体来说,迁移工作主要包括以下几个方面:
311+
312+
- 调整构建脚本,使用 PaddlePaddle 提供的自定义算子编译方式来替换原有的 PyTorch 构建方式。
313+
- 修复 C++ 端的编译错误,主要是由于部分 PyTorch C++ API 尚未覆盖到,需要借助 PaddlePaddle C++ API 来实现等效功能。
314+
- 改写 Python 端的测试脚本,借助 torch proxy 代理层一键替换 PyTorch Python API,并根据实际情况进行部分 API 替换。
315+
316+
目前无论是 C++ 端还是 Python 端的兼容层都还在持续完善中,未来我们会不断补充更多常用 API 的兼容实现,从而进一步降低用户的迁移成本。同时我们也非常欢迎社区用户参与到兼容层的建设中来,共同推动跨生态自定义算子的互通与发展!

0 commit comments

Comments
 (0)