
Commit 7b3a1f2

Perf optimization for topK permute/unpermute; Add README.
Co-authored-by: Shiqing Fan <shiqingf@nvidia.com>
1 parent 7a7f018 commit 7b3a1f2

File tree

11 files changed: +1404 −249 lines changed


README.md

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
<div align="center">

Grouped GEMM for MoE
===========================

<h4>A PyTorch Toolbox for Grouped GEMM in MoE Model Training</h4>

[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)

<div align="left">

- [Grouped GEMM for MoE](#grouped-gemm-for-moe)
  - [Steps for Using](#steps-for-using)
    - [pip install](#pip-install)
    - [Build from Source](#build-from-source)
  - [Support Matrix](#support-matrix)
    - [permute \& unpermute](#permute--unpermute)
  - [Ops Usage](#ops-usage)
    - [permute](#permute)
      - [Parameters](#parameters)
    - [unpermute](#unpermute)
      - [Parameters](#parameters-1)
      - [Example](#example)
---

# Steps for Using

## pip install

```bash
pip install --verbose git+https://github.com/fanshiqing/grouped_gemm@main
```
## Build from Source

```bash
git submodule update --init --recursive
mkdir build
cd build
cmake ..
make -j
cd ..

# GroupedGEMM ops test
python grouped_gemm/ops_test.py

# topK permute & unpermute ops test
python grouped_gemm/permute_test.py

# sinkhorn kernel test
python grouped_gemm/sinkhorn_test.py
```
# Support Matrix

## permute & unpermute

| GPU Arch | FP32  | FP16  | BF16  |  FP8  |
| :------- | :---: | :---: | :---: | :---: |
| SM 70    |   Y   |   Y   |   .   |   Y   |
| SM 75    |   Y   |   Y   |   .   |   Y   |
| SM 80    |   Y   |   Y   |   Y   |   Y   |
| SM 86    |   Y   |   Y   |   Y   |   Y   |
| SM 89    |   Y   |   Y   |   Y   |   Y   |
| SM 90    |   Y   |   Y   |   Y   |   Y   |
# Ops Usage

## permute

> ```py
> grouped_gemm.ops.permute(
>     input_act: torch.Tensor,
>     indices: torch.Tensor,
>     max_token_num: int = 0) -> tuple
> ```

The output is a tuple of `(torch.Tensor, torch.Tensor)` containing two tensors, `permuted_act` and `row_id_map`.

* `permuted_act` is the permutation of the original tensor `input_act`, with its first dimension permuted according to `indices`.
* `row_id_map` is the mapping table for the row indices of the input activations before and after `grouped_gemm.ops.permute`, which is used by the following `unpermute` op.

### Parameters

* **input_act** (torch.Tensor)
&emsp;shape = [tokens_num, hidden_size]
&emsp;The input activations, with each row (token) routed to topK experts.

* **indices** (torch.Tensor)
&emsp;shape = [tokens_num, topK_num]
&emsp;The topK expert indices for each row (token) of activations. The `int32` type is recommended.

* **max_token_num** (int)
&emsp;The maximum number of tokens (rows), used for workspace pre-allocation.

<p align="center"><img src=figures/figure_permute.png></p>
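For illustration, the permute semantics can be mimicked with a plain-PyTorch CPU reference. This is only a sketch, not the CUDA kernel: the helper name `permute_ref` is ours, and the `row_id_map` index layout (`k * tokens_num + t` for token `t`'s `k`-th copy) is inferred from the printed outputs in the Example section below.

```python
import torch

def permute_ref(input_act: torch.Tensor, indices: torch.Tensor):
    """CPU sketch of the permute semantics (illustration only, not the CUDA kernel)."""
    num_tokens, topk = indices.shape
    # Expand each token into its topK copies (row-major over (token, k) pairs)
    # and stable-sort the copies by expert id, so that all rows bound for the
    # same expert become contiguous.
    flat_experts = indices.reshape(-1)                     # [t0k0, t0k1, t1k0, ...]
    sorted_pos = torch.argsort(flat_experts, stable=True)  # source slot of each dest row
    permuted_act = input_act[sorted_pos // topk]
    # row_id_map[k * num_tokens + t] = destination row of token t's k-th copy
    # (layout inferred from the README's example output).
    row_id_map = torch.empty_like(sorted_pos)
    row_id_map[(sorted_pos % topk) * num_tokens + sorted_pos // topk] = \
        torch.arange(sorted_pos.numel())
    return permuted_act, row_id_map.to(torch.int32)

# Same inputs as the Example section below.
indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32)
input_act = torch.arange(4, dtype=torch.float32).repeat_interleave(4).reshape(4, 4)
permuted_act, row_id_map = permute_ref(input_act, indices)
print(row_id_map.tolist())  # [2, 0, 1, 4, 5, 3, 6, 7]
```

Grouping each expert's rows contiguously is what lets the subsequent grouped GEMM treat every expert's tokens as one dense matrix.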
## unpermute

> ```py
> grouped_gemm.ops.unpermute(
>     input_act: torch.Tensor,
>     row_id_map: torch.Tensor,
>     probs: torch.Tensor) -> torch.Tensor
> ```

The mirror operator of `grouped_gemm.ops.permute`.

### Parameters

* **input_act** (torch.Tensor)
&emsp;shape = [tokens_num * topK_num, hidden_size]
&emsp;The permuted activations produced by `grouped_gemm.ops.permute`.

* **row_id_map** (torch.Tensor)
&emsp;shape = [tokens_num * topK_num]
&emsp;The mapping table for the row indices of the activations before and after `grouped_gemm.ops.permute`; the second output tensor of `grouped_gemm.ops.permute`.

* **probs** (torch.Tensor)
&emsp;shape = [tokens_num, topK_num]
&emsp;The weights used to sum the topK copies of each original token coming back from different experts.

<p align="center"><img src=figures/figure_unpermute.png></p>
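To illustrate how `probs` weights the reduction, here is a plain-PyTorch CPU sketch of the unpermute semantics. The helper name `unpermute_ref` is ours, the `permuted_act`/`row_id_map` values are taken from the Example section below, and the non-uniform weights are made up for the illustration.

```python
import torch

def unpermute_ref(permuted_act, row_id_map, probs):
    """CPU sketch of the unpermute semantics (illustration only)."""
    num_tokens, topk = probs.shape
    out = torch.zeros(num_tokens, permuted_act.shape[1], dtype=permuted_act.dtype)
    for k in range(topk):
        for t in range(num_tokens):
            # row_id_map[k * num_tokens + t] is the permuted row holding
            # token t's k-th expert copy; accumulate it with its weight.
            dest = row_id_map[k * num_tokens + t]
            out[t] += probs[t, k] * permuted_act[dest]
    return out

# Permuted values and row_id_map as printed in the Example section below.
row_id_map = torch.tensor([2, 0, 1, 4, 5, 3, 6, 7], dtype=torch.int32)
permuted_act = torch.tensor(
    [[1.]*4, [2.]*4, [0.]*4, [1.]*4, [3.]*4, [0.]*4, [2.]*4, [3.]*4])
# Non-uniform weights: 0.75 for the first expert copy, 0.25 for the second.
probs = torch.tensor([[0.75, 0.25]] * 4)
out = unpermute_ref(permuted_act, row_id_map, probs)
print(out[:, 0].tolist())  # [0.0, 1.0, 2.0, 3.0]
```

Each output row `t` is the `probs`-weighted sum of that token's topK expert outputs, so when the weights of a token sum to 1 the op blends the expert results back into a single row per token.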
### Example

```py
import torch
from grouped_gemm import permute, unpermute

indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32, device='cuda')
input_act = torch.tensor([[0,0,0,0], [1,1,1,1], [2,2,2,2], [3,3,3,3]], dtype=torch.float32, device='cuda')
probs = torch.ones_like(indices, dtype=torch.float32)
permuted_inputs, row_id_map = permute(input_act, indices)
unpermute_outputs = unpermute(permuted_inputs, row_id_map, probs)

print(row_id_map)
print(input_act)
print(permuted_inputs)
print(unpermute_outputs)

# Output
# tensor([2, 0, 1, 4, 5, 3, 6, 7], device='cuda:0', dtype=torch.int32)
# tensor([[0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [3., 3., 3., 3.],
#         [0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [4., 4., 4., 4.],
#         [6., 6., 6., 6.]], device='cuda:0')
```

csrc/ops.cu

Lines changed: 3 additions & 2 deletions
```diff
@@ -9,8 +9,9 @@ namespace grouped_gemm {
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gmm", &GroupedGemm, "Grouped GEMM.");
   m.def("sinkhorn", &sinkhorn, "Sinkhorn kernel");
-  m.def("permute", &moe_permute_op, "Token permutation kernel");
-  m.def("unpermute", &moe_recover_op, "Token un-permutation kernel");
+  m.def("permute", &moe_permute_topK_op, "Token permutation kernel");
+  m.def("unpermute", &moe_recover_topK_op, "Token un-permutation kernel");
+  m.def("unpermute_bwd", &moe_recover_topK_bwd_op, "Token un-permutation backward kernel");
 }

 } // namespace grouped_gemm
```
