Commit 5ba01fa
[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization (#1921)
* [PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization
1. Fused `moe_permute_with_probs` + `Fp8Padding`, and fused `moe_unpermute` + `Fp8Unpadding`. This removes the explicit padding/unpadding step for MoE experts, improving performance and reducing peak GPU memory usage.
2. Added tests for the fused permute/pad and unpermute/unpad paths.
Signed-off-by: xiaoxi-wangfj <[email protected]>
* [PyTorch/Common] Support `with_merging_probs` in the fused permute+pad and unpermute+unpad paths
Signed-off-by: xiaoxi-wangfj <[email protected]>
* [PyTorch] Format code
Signed-off-by: xiaoxi-wangfj <[email protected]>
* [Common] Perf: load `expert_idx` only once
Signed-off-by: xiaoxi-wangfj <[email protected]>
* fix: pad_offsets can be None
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: xiaoxi-wangfj <[email protected]>
* Add backward support for padding + merging probs (not yet tested)
Signed-off-by: tdophung <[email protected]>
* Fix garbage-initialized activation gradient
Signed-off-by: tdophung <[email protected]>
* All tests passing for JAX permutation + pad
Signed-off-by: tdophung <[email protected]>
* Change the `tokens_per_experts` APIs to `num_out_tokens`, with conservative worst-case padding allocation for the output buffer
Signed-off-by: tdophung <[email protected]>
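The conservative worst-case allocation mentioned above can be bounded without knowing the per-expert token counts: rounding each expert group up to a multiple of the alignment adds at most `align - 1` rows per expert. A minimal sketch of that bound (the function name is illustrative, not the actual API):

```python
def worst_case_padded_rows(num_out_tokens, num_experts, align):
    # Each expert group is rounded up to a multiple of `align`, so the
    # round-up adds at most (align - 1) padding rows per expert.
    return num_out_tokens + num_experts * (align - 1)

# 6 routed tokens across 3 experts with align=4 never need more than 15 rows.
print(worst_case_padded_rows(6, 3, 4))
```

Allocating this bound up front lets the output buffer be sized on the host without a device-to-host sync on the actual counts.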
* Shrink the permutation tests to reduce test time
Signed-off-by: tdophung <[email protected]>
* triggering PR refresh
Signed-off-by: tdophung <[email protected]>
* format code
Signed-off-by: tdophung <[email protected]>
* Remove some test cases from the PyTorch side. Add a separate token_dispatch test as a sanity check, in case combine accidentally undoes a dispatch error in the round-trip test. Distinguish L0 and L2 test cases in JAX.
Signed-off-by: tdophung <[email protected]>
* format code
Signed-off-by: tdophung <[email protected]>
* Remove a source of inefficiency when moving between CPU and GPU, remove a redundant primitive by using a new static bool for padding, and add an assert for the alignment size
Signed-off-by: tdophung <[email protected]>
* Fix lint in JAX
Signed-off-by: tdophung <[email protected]>
* Account for JAX versions both newer and older than 0.8.2; adjust the GPU Triton binding accordingly
Signed-off-by: tdophung <[email protected]>
* format code
Signed-off-by: tdophung <[email protected]>
* fix typo
Signed-off-by: tdophung <[email protected]>
---------
Signed-off-by: xiaoxi-wangfj <[email protected]>
Signed-off-by: tdophung <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: tdophung <[email protected]>

1 parent 97a09c2 · commit 5ba01fa
File tree
9 files changed: +2233 −480 lines
- tests
  - jax
  - pytorch
- transformer_engine
  - common/triton
  - jax
    - triton_extensions
  - pytorch
    - triton