Skip to content

Commit f810ee5

Browse files
committed
tmp commit
1 parent c9fa898 commit f810ee5

File tree

4 files changed

+140
-51
lines changed

4 files changed

+140
-51
lines changed

docs/guides/custom_op/cross_ecosystem_custom_op/cross_ecosystem_custom_op_cn.md renamed to docs/guides/custom_op/cross_ecosystem_custom_op/design_and_migration_cn.md

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
1-
# 跨生态自定义算子接入
1+
## 原理和迁移方式
22

3-
## 概述
4-
5-
随着大模型的兴起,在深度学习框架之上构建自定义算子(Custom Operator)已成为提升模型性能和功能的关键手段。而目前 PyTorch 作为深度学习领域的主流框架之一,拥有大量的自定义算子实现。为了帮助用户更好地将现有的 PyTorch 自定义算子迁移至 PaddlePaddle 框架,我们提供了自定义算子兼容机制,旨在降低迁移成本,提升开发效率。
6-
7-
## 方案介绍
3+
### 实现原理
84

95
为了方便 PyTorch 自定义算子快速接入 PaddlePaddle 框架,我们提供了如下图所示的兼容机制:
106

11-
![PyTorch 自定义算子兼容机制示意图](./images/pytorch-op-compatible.drawio.png)
7+
![跨生态自定义算子兼容机制示意图](./images/cross-ecosystem-custom-op-compatible.drawio.png)
128

139
正如图上所示,我们自底向上提供了如下几层支持:
1410

@@ -24,23 +20,3 @@
2420
## 迁移步骤
2521

2622
下面我们以一个简单的 PyTorch 自定义算子为例,介绍如何将其迁移至 PaddlePaddle 框架。
27-
28-
## 已迁移算子库
29-
30-
### FlashInfer
31-
32-
### FlashMLA
33-
34-
### DeepGEMM
35-
36-
### TileLang
37-
38-
### Triton
39-
40-
### TorchCodec
41-
42-
### DeepEP
43-
44-
coming soon...
45-
46-
## 参考资料

docs/guides/custom_op/cross_ecosystem_custom_op/images/pytorch-op-compatible.drawio.png renamed to docs/guides/custom_op/cross_ecosystem_custom_op/images/cross-ecosystem-custom-op-compatible.drawio.png

File renamed without changes.

docs/guides/custom_op/cross_ecosystem_custom_op/index_cn.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
本章节介绍如何将其他深度学习框架算子生态的自定义算子迁移至飞桨框架,主要分为以下内容:
66

77
- `使用指南 <./user_guide.html>`_ 介绍跨生态自定义算子的使用方法,以及已经成功接入的自定义算子库列表。
8-
- `原理和迁移 <./design.html>`_ 介绍跨生态自定义算子的设计原理,以及如何将其他深度学习框架的自定义算子迁移至飞桨框架。
8+
- `原理和迁移方式 <./design_and_migration_cn.html>`_ 介绍跨生态自定义算子的设计原理,以及如何将其他深度学习框架的自定义算子迁移至飞桨框架。
99

1010
.. toctree::
1111
:hidden:
1212

1313
user_guide_cn.md
14-
design_cn.md
14+
design_and_migration_cn.md

docs/guides/custom_op/cross_ecosystem_custom_op/user_guide_cn.md

Lines changed: 135 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,59 +6,172 @@
66

77
## 使用步骤
88

9-
### 一般安装方式
9+
### 安装方式
1010

11-
对于使用基于兼容性方案的跨生态自定义算子库,一般情况下只需要 clone 后通过 pip 安装对应的算子库即可使用。下面以 FlashMLA 为例说明安装方式:
11+
对于使用基于兼容性方案的跨生态自定义算子库,一般情况下只需要 clone 后通过 pip 安装对应的算子库即可使用。下面以 `FlashInfer` 为例说明安装方式:
1212

1313
```bash
14-
git clone https://github.com/PFCCLab/FlashMLA.git
15-
cd FlashMLA
16-
pip install .
14+
pip install paddlepaddle_gpu # Install PaddlePaddle with GPU support, refer to https://www.paddlepaddle.org.cn/install/quick for more details
15+
git clone https://github.com/PFCCLab/flashinfer.git
16+
cd flashinfer
17+
git submodule update --init
18+
pip install apache-tvm-ffi>=0.1.2 # Use TVM FFI 0.1.2 or above
19+
pip install filelock jinja2 # Install tools for jit compilation
20+
# Install FlashInfer
21+
pip install --no-build-isolation . -v
1722
```
1823

19-
对于部分已经发布到 PyPI 的自定义算子库,也可以直接通过 pip 安装。下面以 TorchCodec 为例:
24+
对于部分已经发布到 PyPI 的自定义算子库,也可以直接通过 pip 安装。下面以 `TorchCodec` 为例:
2025

2126
```bash
2227
pip install paddlecodec
2328
```
2429

25-
个别算子库可能会有特殊的安装方式,请参考对应算子库的说明文档进行安装
30+
个别算子库可能会有特殊的安装方式,请参考对应算子库 repo 中的 `README.md` 进行安装
2631

2732
### 使用方式
2833

2934
安装完成后,即可在代码中直接导入并使用对应的算子库。为了实现跨生态兼容,用户需要在导入算子库之前,先启用 PaddlePaddle 的 PyTorch 代理层,以确保算子库中 `torch` 模块的调用能够正确映射到 `paddle` 模块。下面以 FlashMLA 为例说明使用方式:
3035

3136
```python
37+
# 注意,在导入跨生态自定义算子库之前,需先启用 PaddlePaddle 的 PyTorch 代理层
38+
# 即添加下面的两行
3239
import paddle
3340

34-
paddle.compat.enable_torch_proxy({"flash_mla"})
41+
# scope 为限定代理层生效的模块名称空间,避免影响其他模块的使用
42+
paddle.compat.enable_torch_proxy(scope={"flashinfer"})
3543

36-
import flash_mla
37-
# 之后即可使用 flash_mla 下的算子
44+
# 之后即可导入并使用 flashinfer 库
45+
import flashinfer
46+
# 之后即可使用 flashinfer 下的算子,和 PyTorch 生态下的使用方式一致
47+
48+
# 下面以 flashinfer 中的 RMSNorm 算子为例
49+
import numpy as np
50+
51+
def rms_norm(x, w, eps=1e-6):
52+
orig_dtype = x.dtype
53+
x = x.float()
54+
variance = x.pow(2).mean(dim=-1, keepdim=True)
55+
x = x * paddle.rsqrt(variance + eps)
56+
x = x * w.float()
57+
x = x.to(orig_dtype)
58+
return x
59+
60+
batch_size = 99
61+
hidden_size = 1024
62+
dtype = paddle.float16
63+
64+
x = paddle.randn(batch_size, hidden_size).cuda().to(dtype)
65+
w = paddle.randn(hidden_size).cuda().to(dtype)
66+
67+
y_ref = rms_norm(x, w)
68+
69+
y = flashinfer.norm.rmsnorm(x, w, enable_pdl=False)
70+
71+
# flashinfer 算子输出结果与参考实现保持一致
72+
np.testing.assert_allclose(y_ref, y, rtol=1e-3, atol=1e-3)
3873
```
3974

4075
## 已支持的算子库
4176

4277
PaddlePaddle 官方协同社区已经对社区中主流的跨生态自定义算子库进行了适配和测试,用户可以直接使用这些算子库而无需进行额外的修改。
4378

44-
我们将这些算子库统一放在组织 [PFCCLab](https://github.com/PFCCLab) 下,并列在下方。如果下方列表中没有你需要的算子库,可以移步至[原理和迁移](./design_cn.md),了解自定义算子兼容机制的实现原理,以及如何将你需要的算子库进行迁移。
79+
我们将这些算子库统一放在组织 [PFCCLab](https://github.com/PFCCLab) 下,并列在下方。如果下方列表中没有你需要的算子库,可以移步至[原理和迁移方式](./design_and_migration_cn.md),了解自定义算子兼容机制的实现原理,以及如何将你需要的算子库进行迁移。
4580

46-
### FlashInfer
81+
以下是已经支持的跨生态自定义算子库列表:
4782

48-
#### 安装方式
83+
| 算子库名称 | GitHub repo | PyPI 链接 |
84+
| - | - | - |
85+
| FlashInfer | [PFCCLab/flashinfer](https://github.com/PFCCLab/flashinfer) | - |
86+
| FlashMLA | [PFCCLab/FlashMLA](https://github.com/PFCCLab/FlashMLA) | - |
87+
| DeepGEMM | [PFCCLab/DeepGEMM](https://github.com/PFCCLab/DeepGEMM) | - |
88+
| DeepEP | [PFCCLab/DeepGEMM](https://github.com/PFCCLab/DeepEP) | - |
89+
| TorchCodec | [PFCCLab/paddlecodec](https://github.com/PFCCLab/paddlecodec) | [paddlecodec](https://pypi.org/project/paddlecodec/) |
4990

50-
#### 使用方式
91+
## Kernel DSL 生态支持
5192

52-
### FlashMLA
93+
除去自定义算子外,编写自定义算子的方式也在不断演进,涌现出了诸如 Kernel DSL(如 Triton、TileLang)等新兴的编写方式。这些新兴的编写方式在实现中往往或多或少依赖于特定深度学习框架的状态管理接口,从而导致跨生态迁移的难度加大。为此,我们也致力于提升这些新兴编写方式的跨生态兼容性,帮助用户更好地将其迁移至 PaddlePaddle 框架。
5394

54-
### DeepGEMM
95+
我们目前已经支持的 Kernel DSL 生态包括 Triton 和 TileLang。安装方式分别如下:
5596

56-
### DeepEP
57-
58-
### TorchCodec
97+
```bash
98+
# Triton 直接安装官方包即可
99+
pip install triton
100+
# TileLang 目前仍需要安装我们适配后的版本
101+
pip install tilelang-paddle
102+
```
59103

60-
## 已支持的 Kernel DSL
104+
与其他自定义算子库相同,用户同样需要在导入对应的 Kernel DSL 库之前,先启用 PaddlePaddle 的 PyTorch 代理层。下面以 TileLang 为例说明使用方式:
61105

62-
### TileLang
106+
```python
107+
# 同样,在导入跨生态 Kernel DSL 库之前,需先启用 PaddlePaddle 的 PyTorch 代理层
108+
import paddle
63109

64-
### Triton
110+
# 限定生效范围在 TileLang 模块
111+
paddle.compat.enable_torch_proxy(scope={"tilelang"})
112+
113+
# 之后使用方式与官方 PyTorch 生态下保持一致
114+
@tilelang.jit
115+
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
116+
@T.prim_func
117+
def matmul_relu_kernel(
118+
A: T.Tensor((M, K), dtype),
119+
B: T.Tensor((K, N), dtype),
120+
C: T.Tensor((M, N), dtype),
121+
):
122+
# Initialize Kernel Context
123+
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
124+
A_shared = T.alloc_shared((block_M, block_K), dtype)
125+
B_shared = T.alloc_shared((block_K, block_N), dtype)
126+
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
127+
128+
# Enable rasterization for better L2 cache locality (Optional)
129+
# T.use_swizzle(panel_size=10, enable=True)
130+
131+
# Clear local accumulation
132+
T.clear(C_local)
133+
134+
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
135+
# Copy tile of A
136+
# This is a sugar syntax for parallelized copy
137+
T.copy(A[by * block_M, ko * block_K], A_shared)
138+
139+
# Copy tile of B
140+
T.copy(B[ko * block_K, bx * block_N], B_shared)
141+
142+
# Perform a tile-level GEMM on the shared buffers
143+
# Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
144+
T.gemm(A_shared, B_shared, C_local)
145+
146+
# relu
147+
for i, j in T.Parallel(block_M, block_N):
148+
C_local[i, j] = T.max(C_local[i, j], 0)
149+
150+
# Copy result back to global memory
151+
T.copy(C_local, C[by * block_M, bx * block_N])
152+
153+
return matmul_relu_kernel
154+
155+
M = 1024
156+
N = 1024
157+
K = 1024
158+
block_M = 128
159+
block_N = 128
160+
block_K = 32
161+
162+
# 定义并编译 Kernel 函数
163+
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
164+
165+
# 创建随机输入张量
166+
a = paddle.randn(M, K, device="cuda", dtype=paddle.float16)
167+
b = paddle.randn(K, N, device="cuda", dtype=paddle.float16)
168+
c = paddle.empty(M, N, device="cuda", dtype=paddle.float16)
169+
170+
# 运行 kernel
171+
matmul_relu_kernel(a, b, c)
172+
173+
ref_c = paddle.nn.functional.relu(a @ b)
174+
175+
# 结果对齐
176+
np.testing.assert_allclose(c.numpy(), ref_c.numpy(), rtol=1e-2, atol=1e-2)
177+
```

0 commit comments

Comments
 (0)