@@ -359,70 +359,22 @@ def forward(
             tactic: int = -1,
             do_preparation: bool = False,
         ):
-            a, b, a_descale, b_descale, alpha, out, workspace_buffer = inputs
+            workspace_buffer, a, b, a_descale, b_descale, alpha, out = inputs
             module.fp4_gemm.default(
                 a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic
             )
             return out
 
     @register_custom_op(
-        "flashinfer::cutlass_fp4_gemm",
+        "flashinfer::cutlass_fp4_gemm_runner",
         mutates_args=(""),
     )
-    def cutlass_fp4_gemm(
-        a: torch.Tensor,
-        b: torch.Tensor,
-        a_descale: torch.Tensor,
-        b_descale: torch.Tensor,
-        alpha: torch.Tensor,
-        out: torch.Tensor,
-        workspace_buffer: torch.Tensor,
-    ):
-        tuner = AutoTuner.get()
-
-        a_tensor_index = 0
-        a_scale_tensor_index = 2
-        out_tensor_index = 5
-
-        def pad_up(x, y):
-            return ((x + y - 1) // y) * y
-
-        tuning_config = TuningConfig(
-            dynamic_tensor_specs=(
-                DynamicTensorSpec(
-                    a_tensor_index,
-                    0,
-                    get_last_power_of_2_num_tokens_buckets,
-                    last_positive_power_of_2,
-                ),
-            ),
-            constraint_specs=(
-                ConstraintSpec(
-                    a_scale_tensor_index,
-                    0,
-                    lambda shapes: pad_up(shapes[a_tensor_index][0], 128),
-                ),
-                ConstraintSpec(
-                    out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
-                ),
-            ),
-        )
-
-        fp4_runner = CutlassFp4GemmRunner()
-
-        inputs = [a, b, a_descale, b_descale, alpha, out, workspace_buffer]
-        _, tactic = tuner.choose_one(
-            "cutlass_fp4_gemm",
-            [fp4_runner],
-            tuning_config,
-            inputs,
-        )
-
-        fp4_runner(inputs=inputs, tactic=tactic)
+    def cutlass_fp4_gemm_runner():
+        return CutlassFp4GemmRunner()
 
     # Register the module
     return SimpleNamespace(
-        cutlass_fp4_gemm=cutlass_fp4_gemm,
+        cutlass_fp4_gemm_runner=cutlass_fp4_gemm_runner,
     )
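The hunk above replaces the eagerly autotuned `cutlass_fp4_gemm` custom op with a thin `cutlass_fp4_gemm_runner` factory and moves `workspace_buffer` to the front of the runner's `inputs` list. A minimal sketch of the resulting contract, assuming the FP4-packed operands, scale factors, `alpha`, and the output buffer have already been prepared the way `mm_fp4` prepares them (the workspace size below is illustrative, not taken from the diff):

```python
import torch

# Sketch of the new runner-factory contract shown in the hunk above.
module = get_gemm_sm100_module_cutlass_fp4()   # existing module getter in gemm.py
runner = module.cutlass_fp4_gemm_runner()      # returns a CutlassFp4GemmRunner

# Illustrative workspace; the real size requirement comes from the kernel.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# New input ordering: workspace buffer first, output buffer last.
inputs = [workspace_buffer, a, b, a_descale, b_descale, alpha, out]
runner(inputs=inputs, tactic=-1)  # tactic=-1 falls back to the default kernel
```

Calling the runner directly skips the autotuner entirely; tuned dispatch now happens in the shared `fp4_gemm_sm100` helper introduced further down.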
@@ -1470,25 +1422,35 @@ def mm_fp4(
                 f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations."
             )
 
-        get_trtllm_fp4_gemm_module().trtllm_fp4_gemm(
+        fp4_gemm_sm100(
             a,
             b.T,
             a_descale,
             b_descale.T,
             alpha,
             out,
-            use_8x4_sf_layout=use_8x4_sf_layout,
-            workspace_buffer=workspace_buffer,
+            workspace_buffer,
+            use_8x4_sf_layout,
+            ["trtllm"],
         )
     elif backend == "cutlass":
         # cutlass require uint8 scale when a/b is fp4 packed uint8.
         if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
             a_descale = a_descale.view(torch.uint8)
         if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn:
             b_descale = b_descale.view(torch.uint8)
-        get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm(
-            a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer
+        fp4_gemm_sm100(
+            a,
+            b.T,
+            a_descale,
+            b_descale.T,
+            alpha,
+            out,
+            workspace_buffer,
+            use_8x4_sf_layout,
+            ["cutlass"],
         )
+
     return out
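Both backend branches of `mm_fp4` now funnel into the single `fp4_gemm_sm100` helper, with `workspace_buffer`, `use_8x4_sf_layout`, and the runner-name list passed positionally instead of the previous keyword arguments. A sketch of that shared call shape (operand preparation omitted; only the argument order is the point):

```python
# Argument order expected by the new shared entry point (from the hunk above).
fp4_gemm_sm100(
    a,                  # FP4-packed activation
    b.T,                # FP4-packed weight, transposed as before
    a_descale,          # block scale factors for a
    b_descale.T,        # block scale factors for b, transposed as before
    alpha,              # global scale
    out,                # preallocated output buffer
    workspace_buffer,   # positional now, no longer a keyword argument
    use_8x4_sf_layout,  # positional now as well
    ["trtllm"],         # runner candidates; the cutlass branch passes ["cutlass"]
)
```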
@@ -1782,76 +1744,92 @@ def forward(
             return out
 
     @register_custom_op(
-        "flashinfer::trtllm_fp4_gemm",
+        "flashinfer::trtllm_fp4_gemm_runner",
         mutates_args=(""),
     )
-    def trtllm_fp4_gemm(
-        a: torch.Tensor,
-        b: torch.Tensor,
-        a_descale: torch.Tensor,
-        b_descale: torch.Tensor,
-        alpha: torch.Tensor,
-        out: torch.Tensor,
-        use_8x4_sf_layout: bool,
-        workspace_buffer: torch.Tensor,
-    ):
-        tuner = AutoTuner.get()
-
-        a_tensor_index = 1
-        a_scale_tensor_index = 3
-        out_tensor_index = 6
-
-        def pad_up(x, y):
-            return ((x + y - 1) // y) * y
-
-        tuning_config = TuningConfig(
-            dynamic_tensor_specs=(
-                DynamicTensorSpec(
-                    a_tensor_index,
-                    0,
-                    get_last_power_of_2_num_tokens_buckets,
-                    last_positive_power_of_2,
-                ),
-            ),
-            constraint_specs=(
-                ConstraintSpec(
-                    a_scale_tensor_index,
-                    0,
-                    lambda shapes: pad_up(
-                        shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128
-                    ),
-                ),
-                ConstraintSpec(
-                    out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
-                ),
-            ),
-        )
+    def trtllm_fp4_gemm_runner(use_8x4_sf_layout):
+        return TrtllmFp4GemmRunner(use_8x4_sf_layout)
 
-        fp4_runner = TrtllmFp4GemmRunner(use_8x4_sf_layout)
+    # Register the module
+    return SimpleNamespace(
+        trtllm_fp4_gemm_runner=trtllm_fp4_gemm_runner,
+    )
 
-        inputs = [
-            workspace_buffer,
-            a,
-            b,
-            a_descale,
-            b_descale,
-            alpha,
-            out,
-        ]
-        _, tactic = tuner.choose_one(
-            "trtllm_fp4_gemm_8x4" if use_8x4_sf_layout else "trtllm_fp4_gemm_128x4",
-            [fp4_runner],
-            tuning_config,
-            inputs,
-        )
 
-        fp4_runner(inputs=inputs, tactic=tactic)
+def fp4_gemm_sm100(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: torch.Tensor,
+    out: torch.Tensor,
+    workspace_buffer: torch.Tensor,
+    use_8x4_sf_layout: bool,
+    runner_names: List[str],
+):
+    runners = []
+
+    if "trtllm" in runner_names:
+        runners.append(
+            get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner(
+                use_8x4_sf_layout=use_8x4_sf_layout
+            )
+        )
+    if "cutlass" in runner_names and not use_8x4_sf_layout:
+        runners.append(get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner())
+    if len(runners) == 0:
+        raise ValueError("No runner specified")
+
+    tuner = AutoTuner.get()
+
+    a_tensor_index = 1
+    a_scale_tensor_index = 3
+    out_tensor_index = 6
+
+    def pad_up(x, y):
+        return ((x + y - 1) // y) * y
+
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(
+            DynamicTensorSpec(
+                a_tensor_index,
+                0,
+                get_last_power_of_2_num_tokens_buckets,
+                last_positive_power_of_2,
+            ),
+        ),
+        constraint_specs=(
+            ConstraintSpec(
+                a_scale_tensor_index,
+                0,
+                lambda shapes: pad_up(
+                    shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128
+                ),
+            ),
+            ConstraintSpec(
+                out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
+            ),
+        ),
+    )
 
-    # Register the module
-    return SimpleNamespace(
-        trtllm_fp4_gemm=trtllm_fp4_gemm,
+    inputs = [
+        workspace_buffer,
+        a,
+        b,
+        a_descale,
+        b_descale,
+        alpha,
+        out,
+    ]
+    runner, tactic = tuner.choose_one(
+        f"fp4_gemm_auto_{'8x4' if use_8x4_sf_layout else '128x4'}",
+        runners,
+        tuning_config,
+        inputs,
     )
 
+    runner(inputs=inputs, tactic=tactic)
+
 
 def gemm_fp8_nt_blockscaled(
     a: torch.Tensor,
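In the new module-level `fp4_gemm_sm100` above, the candidate list is built from the requested runner names (the cutlass runner is only eligible when `use_8x4_sf_layout` is false), and `AutoTuner.choose_one` then selects both the runner and the tactic under one tuning config. The constraint specs pin the scale tensor's leading dimension to the token count of `a`, padded to the scale-factor layout granularity. A small self-contained check of that padding rule, mirroring the `8 if use_8x4_sf_layout else 128` expression in the hunk:

```python
def pad_up(x: int, y: int) -> int:
    # Same rounding helper as in the hunk: round x up to a multiple of y.
    return ((x + y - 1) // y) * y

# 300 rows of scale factors under each layout:
assert pad_up(300, 128) == 384  # 128x4 scale-factor layout
assert pad_up(300, 8) == 304    # 8x4 scale-factor layout (trtllm runner only)
```

Note that requesting `["cutlass"]` together with the 8x4 layout leaves the runner list empty, so the helper raises `ValueError("No runner specified")`.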