
Commit 5ba01fa

xiaoxi-wangfj, timmoon10, and tdophung authored
[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization (#1921)
* [PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization
  1. Fused `moe_permute_with_probs` + `Fp8Padding` and fused `moe_unpermute` + `Fp8Unpadding`, removing the explicit padding/unpadding of MoE experts, improving performance, and reducing peak GPU memory usage.
  2. Added tests for the fused permute/pad and unpermute/unpad paths.
  Signed-off-by: xiaoxi-wangfj <[email protected]>
* [PyTorch/Common] Fused permute+pad and unpermute+unpad: support with_merging_probs
  Signed-off-by: xiaoxi-wangfj <[email protected]>
* [PyTorch] Format code
  Signed-off-by: xiaoxi-wangfj <[email protected]>
* [Common] Perf: load expert_idx only once
  Signed-off-by: xiaoxi-wangfj <[email protected]>
* fix: pad_offsets can be None
  Co-authored-by: Tim Moon <[email protected]>
  Signed-off-by: xiaoxi-wangfj <[email protected]>
* Add padding + merging probs bwd support (not yet tested)
  Signed-off-by: tdophung <[email protected]>
* Fix garbage-initialized activation grad
  Signed-off-by: tdophung <[email protected]>
* All tests passing for JAX permutation + pad
  Signed-off-by: tdophung <[email protected]>
* Change tokens_per_experts APIs to num_out_tokens, with conservative worst-case padding allocation for the output buffer
  Signed-off-by: tdophung <[email protected]>
* Change permutation test to reduce test time
  Signed-off-by: tdophung <[email protected]>
* Trigger PR refresh
  Signed-off-by: tdophung <[email protected]>
* Format code
  Signed-off-by: tdophung <[email protected]>
* Remove some test cases from the PyTorch side. Add a separate token_dispatch test for sanity, in case combine accidentally undoes an error on dispatch in the round-trip test. Add distinction between L0 and L2 in the JAX test cases
  Signed-off-by: tdophung <[email protected]>
* Format code
  Signed-off-by: tdophung <[email protected]>
* Remove a chance for inefficiency in moving between CPU and GPU, remove a redundant primitive by using a new static bool for padding, add an assert for align size
  Signed-off-by: tdophung <[email protected]>
* Fix lint in JAX
  Signed-off-by: tdophung <[email protected]>
* Account for JAX both newer and older than version 0.8.2; adjust the GPU Triton binding accordingly
  Signed-off-by: tdophung <[email protected]>
* Format code
  Signed-off-by: tdophung <[email protected]>
* Fix typo
  Signed-off-by: tdophung <[email protected]>

---------

Signed-off-by: xiaoxi-wangfj <[email protected]>
Signed-off-by: tdophung <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: tdophung <[email protected]>
1 parent 97a09c2 commit 5ba01fa
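The core of this change is to fold the FP8 alignment padding into the MoE permute itself: rather than calling `moe_permute_with_probs` and then `Fp8Padding` (and, on the way back, `Fp8Unpadding` before `moe_unpermute`), the fused kernels write each token straight into its padded destination row. Below is a minimal plain-PyTorch sketch of that idea, not the Transformer Engine API; the helper names (`pad_offsets_for`, `fused_permute_pad`) and argument layout are illustrative assumptions.

```python
import torch

def pad_offsets_for(tokens_per_expert: torch.Tensor, align: int) -> torch.Tensor:
    """Cumulative number of padding rows inserted before each expert's chunk."""
    padded = (tokens_per_expert + align - 1) // align * align
    pad_per_expert = padded - tokens_per_expert
    return torch.cumsum(pad_per_expert, dim=0) - pad_per_expert

def fused_permute_pad(tokens, dst_rows, expert_of_row, tokens_per_expert, align=16):
    """Scatter each routed token to its permuted row, shifted by its expert's pad offset."""
    offsets = pad_offsets_for(tokens_per_expert, align)
    padded_rows = dst_rows + offsets[expert_of_row]
    num_padded_rows = int(((tokens_per_expert + align - 1) // align * align).sum())
    out = torch.zeros(num_padded_rows, tokens.shape[1],
                      dtype=tokens.dtype, device=tokens.device)
    out[padded_rows] = tokens  # padding rows stay zero
    return out
```

Because the padded layout is produced in a single pass, the intermediate unpadded buffer never needs to be materialized, which is where the reported speedup and peak-memory reduction come from.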

File tree

9 files changed, +2233 -480 lines changed


tests/jax/test_permutation.py: 565 additions & 332 deletions (large diff not rendered by default)

tests/pytorch/test_permutation.py: 649 additions & 9 deletions (large diff not rendered by default)

transformer_engine/common/triton/permutation.py: 29 additions & 9 deletions
@@ -200,6 +200,7 @@ def _permute_kernel(
     probs_ptr,
     scale_ptr,
     permuted_scale_ptr,
+    pad_offsets_ptr,
     # sizes
     scale_hidden_dim,
     # strides
@@ -224,8 +225,11 @@ def _permute_kernel(
     hidden_size: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
     PERMUTE_SCALE: tl.constexpr,
+    FUSION_PAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
+    expert_idx = 0
+
     pid_t = tl.program_id(0)
     pid_h = tl.program_id(1)
     cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -246,18 +250,22 @@ def _permute_kernel(
         dst_row = tl.load(
             row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
         ).to(tl.int64)
+        if FUSION_PAD or PERMUTE_PROBS:
+            expert_idx = tl.load(
+                row_id_map_ptr
+                + pid_t * stride_row_id_map_token
+                + (num_experts + idx) * stride_row_id_map_expert
+            )
+        if FUSION_PAD:
+            pad_off = tl.load(pad_offsets_ptr + expert_idx)
+            dst_row = dst_row + pad_off
         output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
         if PERMUTE_SCALE:
             permuted_scale_off = (
                 dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
             )
             tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
         if PERMUTE_PROBS:
-            expert_idx = tl.load(
-                row_id_map_ptr
-                + pid_t * stride_row_id_map_token
-                + (num_experts + idx) * stride_row_id_map_expert
-            )
             prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
             prob = tl.load(probs_ptr + prob_off)
             if pid_h == 0:
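For reference, here is the offset arithmetic that the added `dst_row = dst_row + pad_off` line relies on, assuming `pad_offsets[e]` holds the cumulative number of padding rows inserted before expert `e`'s chunk. This is a worked example with made-up sizes, not code from this file.

```python
import torch

# With tokens_per_expert = [3, 5, 2] and an FP8 alignment of 4, the padded
# chunk sizes are [4, 8, 4], so pad_offsets = [0, 1, 4]. A token whose
# unpadded destination row is 3 (the first row of expert 1) moves to row 3 + 1 = 4.
tokens_per_expert = torch.tensor([3, 5, 2])
align = 4
padded = (tokens_per_expert + align - 1) // align * align      # [4, 8, 4]
pad_per_expert = padded - tokens_per_expert                     # [1, 3, 2]
pad_offsets = torch.cumsum(pad_per_expert, 0) - pad_per_expert  # [0, 1, 4]
print(pad_offsets.tolist())  # [0, 1, 4]
```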
@@ -297,6 +305,7 @@ def _unpermute_kernel(
     row_id_map_ptr,
     merging_probs_ptr,
     permuted_probs_ptr,
+    pad_offsets_ptr,
     # strides
     stride_row_id_map_token,
     stride_row_id_map_expert,
@@ -318,10 +327,12 @@ def _unpermute_kernel(
     PROBS_LOAD_WIDTH: tl.constexpr,
     WITH_MERGING_PROBS: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
+    FUSION_UNPAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     data_type = input_ptr.dtype.element_ty
     compute_type = tl.float32
+    expert_idx = 0
 
     pid_t = tl.program_id(0)
     pid_h = tl.program_id(1)
@@ -348,15 +359,19 @@ def _unpermute_kernel(
         src_row = tl.load(
             row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
         ).to(tl.int64)
-        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
-        inp = tl.load(input_ptr + input_off, mask=mask)
-        inp = inp.to(compute_type)
-        if WITH_MERGING_PROBS:
+        if FUSION_UNPAD or WITH_MERGING_PROBS:
             expert_idx = tl.load(
                 row_id_map_ptr
                 + pid_t * stride_row_id_map_token
                 + (num_experts + idx) * stride_row_id_map_expert
             )
+        if FUSION_UNPAD:
+            pad_off = tl.load(pad_offsets_ptr + expert_idx)
+            src_row = src_row + pad_off
+        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
+        inp = tl.load(input_ptr + input_off, mask=mask)
+        inp = inp.to(compute_type)
+        if WITH_MERGING_PROBS:
             merging_prob_off = (
                 pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
             )
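The unpermute side mirrors the forward path: with `FUSION_UNPAD`, the kernel shifts the source row by the same per-expert offset so it reads directly out of the padded buffer, skipping the separate `Fp8Unpadding` pass. A rough plain-PyTorch equivalent, reusing `pad_offsets_for` from the earlier sketch and with illustrative argument names, not the TE implementation:

```python
def fused_unpermute_unpad(padded_tokens, src_rows, dest_token, expert_of_row,
                          tokens_per_expert, num_tokens, merging_probs=None, align=16):
    """Gather each routed copy from its padded row, weight it, and sum copies per token."""
    offsets = pad_offsets_for(tokens_per_expert, align)
    rows = src_rows + offsets[expert_of_row]      # skip over the padding rows when reading
    gathered = padded_tokens[rows]
    if merging_probs is not None:                 # WITH_MERGING_PROBS path
        gathered = gathered * merging_probs.unsqueeze(-1)
    out = torch.zeros(num_tokens, padded_tokens.shape[1],
                      dtype=padded_tokens.dtype, device=padded_tokens.device)
    out.index_add_(0, dest_token, gathered)       # combine the top-k expert copies per token
    return out
```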
@@ -407,6 +422,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
     fwd_input_ptr,
     merging_probs_ptr,
     row_id_map_ptr,
+    pad_offsets_ptr,
     # strides
     stride_row_id_map_token,
     stride_row_id_map_expert,
@@ -427,6 +443,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
     num_experts: tl.constexpr,
     hidden_size: tl.constexpr,
     PROBS_LOAD_WIDTH: tl.constexpr,
+    FUSION_UNPAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     data_type = fwd_output_grad_ptr.dtype.element_ty
@@ -450,6 +467,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
             + pid * stride_row_id_map_token
             + (num_experts + idx) * stride_row_id_map_expert
         )
+        if FUSION_UNPAD:
+            pad_off = tl.load(pad_offsets_ptr + expert_idx)
+            dst_row = dst_row + pad_off
         prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
         current_start = 0
         while current_start < hidden_size:
