<div align="center">

Grouped GEMM for MoE
===========================
<h4>A PyTorch Toolbox for Grouped GEMM in MoE Model Training</h4>

[](./LICENSE)

<div align="left">

- [Grouped GEMM for MoE](#grouped-gemm-for-moe)
- [Steps for Using](#steps-for-using)
  - [pip install](#pip-install)
  - [Build from Source](#build-from-source)
- [Support Matrix](#support-matrix)
  - [permute \& unpermute](#permute--unpermute)
- [Ops Usage](#ops-usage)
  - [permute](#permute)
    - [Parameters](#parameters)
  - [unpermute](#unpermute)
    - [Parameters](#parameters-1)
  - [Example](#example)

---

# Steps for Using

## pip install
```bash
pip install --verbose git+https://github.com/fanshiqing/grouped_gemm@main
```

## Build from Source
```bash
git submodule update --init --recursive
mkdir build
cd build
cmake ..
make -j
cd ..

# GroupedGEMM ops test
python grouped_gemm/ops_test.py

# topK permute & unpermute ops test
python grouped_gemm/permute_test.py

# sinkhorn kernel test
python grouped_gemm/sinkhorn_test.py
```

# Support Matrix

## permute & unpermute

| GPU Arch | FP32 | FP16 | BF16 | FP8 |
| :------- | :--: | :--: | :--: | :--: |
| SM 70    |  Y   |  Y   |  .   |  Y   |
| SM 75    |  Y   |  Y   |  .   |  Y   |
| SM 80    |  Y   |  Y   |  Y   |  Y   |
| SM 86    |  Y   |  Y   |  Y   |  Y   |
| SM 89    |  Y   |  Y   |  Y   |  Y   |
| SM 90    |  Y   |  Y   |  Y   |  Y   |
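
As a quick illustration, the support matrix above can be mirrored in a small runtime guard. This is only a sketch; `is_supported` is a hypothetical helper, not part of this library:

```python
# Hypothetical helper encoding the support matrix above:
# BF16 permute/unpermute requires SM 80 or newer; FP32/FP16/FP8 are
# listed as supported from SM 70 onward.
def is_supported(sm_arch: int, dtype: str) -> bool:
    if dtype not in ("fp32", "fp16", "bf16", "fp8"):
        raise ValueError(f"unknown dtype: {dtype}")
    if dtype == "bf16":
        return sm_arch >= 80
    return sm_arch >= 70

print(is_supported(75, "bf16"))  # False: no BF16 before SM 80
print(is_supported(90, "fp8"))   # True
```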

# Ops Usage

## permute

> ```py
> grouped_gemm.ops.permute(
>     input_act: torch.Tensor,
>     indices: torch.Tensor,
>     max_token_num: int = 0) -> tuple
> ```

The output is a tuple of `(torch.Tensor, torch.Tensor)` containing two tensors, `permuted_act` and `row_id_map`.

* `permuted_act` is the permutation of the original tensor `input_act`, with its first dimension permuted according to `indices`.
* `row_id_map` is the mapping table for the row indices of the input activations before and after `grouped_gemm.ops.permute`, which is consumed by the subsequent `unpermute` op.

### Parameters

* **input_act** (torch.Tensor)
  shape = [tokens_num, hidden_size]
  The input activations, where each row (token) is routed to topK experts.

* **indices** (torch.Tensor)
  shape = [tokens_num, topK_num]
  The topK expert indices for each row (token) of activations. The `int32` type is recommended.

* **max_token_num** (int)
  The maximum number of tokens (rows), used for workspace pre-allocation.

<p align="center"><img src=figures/figure_permute.png></p>
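
To make the row shuffling concrete, here is a pure-Python reference model of the op's behavior, reconstructed from the example at the end of this README. `permute_ref` is a hypothetical illustration of the semantics, not the actual CUDA kernel:

```python
# Pure-Python sketch of the permute semantics (reference model only).
# Each token's row is replicated once per selected expert; the replicated
# rows are grouped by expert id, preserving original token order within
# each expert group.
def permute_ref(input_act, indices):
    num_tokens, topk = len(indices), len(indices[0])
    # Sort the (expert, token, k) triples: rows grouped by expert,
    # then by original token index.
    order = sorted(
        (indices[t][k], t, k) for t in range(num_tokens) for k in range(topk)
    )
    permuted_act = [list(input_act[t]) for _, t, _ in order]
    # row_id_map[k * num_tokens + t] = destination row of token t's k-th copy.
    row_id_map = [0] * (num_tokens * topk)
    for dest, (_, t, k) in enumerate(order):
        row_id_map[k * num_tokens + t] = dest
    return permuted_act, row_id_map

indices = [[1, 2], [0, 1], [0, 2], [1, 2]]
acts = [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]
_, row_map = permute_ref(acts, indices)
print(row_map)  # [2, 0, 1, 4, 5, 3, 6, 7] -- matches the example below
```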

## unpermute

> ```py
> grouped_gemm.ops.unpermute(
>     input_act: torch.Tensor,
>     row_id_map: torch.Tensor,
>     probs: torch.Tensor) -> torch.Tensor
> ```

The mirror operator of `grouped_gemm.ops.permute`: it restores the original row order and reduces the topK copies of each token into a single row, weighted by `probs`.

### Parameters

* **input_act** (torch.Tensor)
  shape = [tokens_num * topK_num, hidden_size]
  The permuted activations produced by `grouped_gemm.ops.permute`.

* **row_id_map** (torch.Tensor)
  shape = [tokens_num * topK_num]
  The mapping table for the row indices of the activations before and after `grouped_gemm.ops.permute`; this is the second output tensor of `grouped_gemm.ops.permute`.

* **probs** (torch.Tensor)
  shape = [tokens_num, topK_num]
  The weights used to sum up the topK copies of each original token returned from different experts.

<p align="center"><img src=figures/figure_unpermute.png></p>
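
The reduction can likewise be sketched in pure Python, again reconstructed from the README's example. `unpermute_ref` is a hypothetical reference model, not the actual CUDA kernel:

```python
# Pure-Python sketch of the unpermute semantics (reference model only).
# Each token gathers its topK copies back via row_id_map and reduces them
# into one row, weighted by probs.
def unpermute_ref(permuted_act, row_id_map, probs):
    num_tokens, topk = len(probs), len(probs[0])
    hidden = len(permuted_act[0])
    out = [[0.0] * hidden for _ in range(num_tokens)]
    for t in range(num_tokens):
        for k in range(topk):
            # row_id_map[k * num_tokens + t] locates token t's k-th copy.
            src = row_id_map[k * num_tokens + t]
            for h in range(hidden):
                out[t][h] += probs[t][k] * permuted_act[src][h]
    return out
```

With the all-ones `probs` of the example below, each output row is simply the sum of that token's two expert copies, e.g. token 1 yields `[2.0, 2.0, 2.0, 2.0]`.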

### Example

```py
import torch
from grouped_gemm import permute, unpermute

indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32, device='cuda')
input_act = torch.tensor([[0,0,0,0], [1,1,1,1], [2,2,2,2], [3,3,3,3]], dtype=torch.float32, device='cuda')
probs = torch.ones_like(indices, dtype=torch.float32)
permuted_inputs, row_id_map = permute(input_act, indices)
unpermute_outputs = unpermute(permuted_inputs, row_id_map, probs)

print(row_id_map)
print(input_act)
print(permuted_inputs)
print(unpermute_outputs)

# Output
# tensor([2, 0, 1, 4, 5, 3, 6, 7], device='cuda:0', dtype=torch.int32)
# tensor([[0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [3., 3., 3., 3.],
#         [0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [4., 4., 4., 4.],
#         [6., 6., 6., 6.]], device='cuda:0')
```