
Commit 679eb99

refactor: unify autotuner for fp4 backends
Signed-off-by: Vincent Huang <[email protected]>
1 parent ade3885 commit 679eb99
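
This commit collapses the two backend-specific autotuning entry points (cutlass_fp4_gemm and trtllm_fp4_gemm) into thin runner factories plus one shared fp4_gemm_sm100 helper that hands every requested runner to a single AutoTuner.choose_one call. As a rough mental model only (a standalone sketch; ToyRunner, toy_choose_one and the timing loop are illustrative stand-ins, not flashinfer APIs), the unified pattern looks like this:

import time
from typing import List, Tuple

# Illustrative stand-in for a backend runner: it knows its valid tactics and
# can execute the same GEMM with any of them.
class ToyRunner:
    def __init__(self, name: str, tactics: List[int]):
        self.name = name
        self.tactics = tactics

    def __call__(self, inputs, tactic: int):
        time.sleep(0.0001 * (tactic + 1))  # a real runner would launch a kernel
        return inputs

def toy_choose_one(runners: List[ToyRunner], inputs) -> Tuple[ToyRunner, int]:
    """Time every (runner, tactic) pair once and keep the fastest."""
    best = None
    for runner in runners:
        for tactic in runner.tactics:
            start = time.perf_counter()
            runner(inputs, tactic)
            elapsed = time.perf_counter() - start
            if best is None or elapsed < best[0]:
                best = (elapsed, runner, tactic)
    return best[1], best[2]

# One tuning pass now covers every eligible backend instead of one per backend.
runners = [ToyRunner("cutlass", [0, 1, 2]), ToyRunner("trtllm", [0, 1])]
runner, tactic = toy_choose_one(runners, inputs=None)
print(f"selected {runner.name}, tactic {tactic}")
runner(None, tactic)  # dispatch with the winning (runner, tactic) pair

When a caller passes both backend names, the tuner can compare cutlass and trtllm tactics under one cache key; mm_fp4 itself still pins the backend it was asked for by passing ["trtllm"] or ["cutlass"].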

1 file changed: +98 -120 lines changed

flashinfer/gemm.py

Lines changed: 98 additions & 120 deletions
@@ -359,70 +359,22 @@ def forward(
             tactic: int = -1,
             do_preparation: bool = False,
         ):
-            a, b, a_descale, b_descale, alpha, out, workspace_buffer = inputs
+            workspace_buffer, a, b, a_descale, b_descale, alpha, out = inputs
             module.fp4_gemm.default(
                 a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic
             )
             return out

     @register_custom_op(
-        "flashinfer::cutlass_fp4_gemm",
+        "flashinfer::cutlass_fp4_gemm_runner",
         mutates_args=(""),
     )
-    def cutlass_fp4_gemm(
-        a: torch.Tensor,
-        b: torch.Tensor,
-        a_descale: torch.Tensor,
-        b_descale: torch.Tensor,
-        alpha: torch.Tensor,
-        out: torch.Tensor,
-        workspace_buffer: torch.Tensor,
-    ):
-        tuner = AutoTuner.get()
-
-        a_tensor_index = 0
-        a_scale_tensor_index = 2
-        out_tensor_index = 5
-
-        def pad_up(x, y):
-            return ((x + y - 1) // y) * y
-
-        tuning_config = TuningConfig(
-            dynamic_tensor_specs=(
-                DynamicTensorSpec(
-                    a_tensor_index,
-                    0,
-                    get_last_power_of_2_num_tokens_buckets,
-                    last_positive_power_of_2,
-                ),
-            ),
-            constraint_specs=(
-                ConstraintSpec(
-                    a_scale_tensor_index,
-                    0,
-                    lambda shapes: pad_up(shapes[a_tensor_index][0], 128),
-                ),
-                ConstraintSpec(
-                    out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
-                ),
-            ),
-        )
-
-        fp4_runner = CutlassFp4GemmRunner()
-
-        inputs = [a, b, a_descale, b_descale, alpha, out, workspace_buffer]
-        _, tactic = tuner.choose_one(
-            "cutlass_fp4_gemm",
-            [fp4_runner],
-            tuning_config,
-            inputs,
-        )
-
-        fp4_runner(inputs=inputs, tactic=tactic)
+    def cutlass_fp4_gemm_runner():
+        return CutlassFp4GemmRunner()

     # Register the module
     return SimpleNamespace(
-        cutlass_fp4_gemm=cutlass_fp4_gemm,
+        cutlass_fp4_gemm_runner=cutlass_fp4_gemm_runner,
     )


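With the autotuning loop gone, the CUTLASS runner's forward now unpacks inputs in the order the unified helper builds them, with workspace_buffer moved to the front; that is why the tensor indices used by the shared tuning config shift from 0/2/5 to 1/3/6. A quick sanity check of the new positions (plain strings stand in for the tensors; only the indices matter):

# Unified inputs ordering used by fp4_gemm_sm100 and by each runner's forward:
# workspace_buffer now sits at index 0, so a, a_descale and out land at the
# indices the shared tuning config points at (1, 3 and 6).
inputs = ["workspace_buffer", "a", "b", "a_descale", "b_descale", "alpha", "out"]

assert inputs.index("a") == 1          # a_tensor_index
assert inputs.index("a_descale") == 3  # a_scale_tensor_index
assert inputs.index("out") == 6        # out_tensor_index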
@@ -1470,25 +1422,35 @@ def mm_fp4(
                 f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations."
             )

-        get_trtllm_fp4_gemm_module().trtllm_fp4_gemm(
+        fp4_gemm_sm100(
             a,
             b.T,
             a_descale,
             b_descale.T,
             alpha,
             out,
-            use_8x4_sf_layout=use_8x4_sf_layout,
-            workspace_buffer=workspace_buffer,
+            workspace_buffer,
+            use_8x4_sf_layout,
+            ["trtllm"],
         )
     elif backend == "cutlass":
         # cutlass require uint8 scale when a/b is fp4 packed uint8.
         if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
             a_descale = a_descale.view(torch.uint8)
         if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn:
             b_descale = b_descale.view(torch.uint8)
-        get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm(
-            a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer
+        fp4_gemm_sm100(
+            a,
+            b.T,
+            a_descale,
+            b_descale.T,
+            alpha,
+            out,
+            workspace_buffer,
+            use_8x4_sf_layout,
+            ["cutlass"],
         )
+
     return out


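Both branches of mm_fp4 now funnel into fp4_gemm_sm100, differing only in the runner_names list; the cutlass branch additionally reinterprets float8_e4m3fn block scales as uint8 before the call. A minimal illustration of what that .view(dtype) does, assuming a PyTorch build with float8_e4m3fn support (roughly 2.1+) and toy values:

import torch

# The cutlass path reinterprets fp8 block scales as raw bytes rather than
# converting them: .view(dtype) only changes the dtype tag, not the storage.
scale = torch.randn(4).to(torch.float8_e4m3fn)   # toy stand-in for a_descale
scale_u8 = scale.view(torch.uint8)

assert scale_u8.dtype == torch.uint8
assert scale_u8.data_ptr() == scale.data_ptr()   # same memory, no copy
print(scale_u8)                                  # the underlying byte pattern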
@@ -1782,76 +1744,92 @@ def forward(
             return out

     @register_custom_op(
-        "flashinfer::trtllm_fp4_gemm",
+        "flashinfer::trtllm_fp4_gemm_runner",
         mutates_args=(""),
     )
-    def trtllm_fp4_gemm(
-        a: torch.Tensor,
-        b: torch.Tensor,
-        a_descale: torch.Tensor,
-        b_descale: torch.Tensor,
-        alpha: torch.Tensor,
-        out: torch.Tensor,
-        use_8x4_sf_layout: bool,
-        workspace_buffer: torch.Tensor,
-    ):
-        tuner = AutoTuner.get()
-
-        a_tensor_index = 1
-        a_scale_tensor_index = 3
-        out_tensor_index = 6
-
-        def pad_up(x, y):
-            return ((x + y - 1) // y) * y
-
-        tuning_config = TuningConfig(
-            dynamic_tensor_specs=(
-                DynamicTensorSpec(
-                    a_tensor_index,
-                    0,
-                    get_last_power_of_2_num_tokens_buckets,
-                    last_positive_power_of_2,
-                ),
-            ),
-            constraint_specs=(
-                ConstraintSpec(
-                    a_scale_tensor_index,
-                    0,
-                    lambda shapes: pad_up(
-                        shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128
-                    ),
-                ),
-                ConstraintSpec(
-                    out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
-                ),
-            ),
-        )
+    def trtllm_fp4_gemm_runner(use_8x4_sf_layout):
+        return TrtllmFp4GemmRunner(use_8x4_sf_layout)

-        fp4_runner = TrtllmFp4GemmRunner(use_8x4_sf_layout)
+    # Register the module
+    return SimpleNamespace(
+        trtllm_fp4_gemm_runner=trtllm_fp4_gemm_runner,
+    )

-        inputs = [
-            workspace_buffer,
-            a,
-            b,
-            a_descale,
-            b_descale,
-            alpha,
-            out,
-        ]
-        _, tactic = tuner.choose_one(
-            "trtllm_fp4_gemm_8x4" if use_8x4_sf_layout else "trtllm_fp4_gemm_128x4",
-            [fp4_runner],
-            tuning_config,
-            inputs,
-        )

-        fp4_runner(inputs=inputs, tactic=tactic)
+def fp4_gemm_sm100(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: torch.Tensor,
+    out: torch.Tensor,
+    workspace_buffer: torch.Tensor,
+    use_8x4_sf_layout: bool,
+    runner_names: List[str],
+):
+    runners = []
+
+    if "trtllm" in runner_names:
+        runners.append(
+            get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner(
+                use_8x4_sf_layout=use_8x4_sf_layout
+            )
+        )
+    if "cutlass" in runner_names and not use_8x4_sf_layout:
+        runners.append(get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner())
+    if len(runners) == 0:
+        raise ValueError("No runner specified")
+
+    tuner = AutoTuner.get()
+
+    a_tensor_index = 1
+    a_scale_tensor_index = 3
+    out_tensor_index = 6
+
+    def pad_up(x, y):
+        return ((x + y - 1) // y) * y
+
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(
+            DynamicTensorSpec(
+                a_tensor_index,
+                0,
+                get_last_power_of_2_num_tokens_buckets,
+                last_positive_power_of_2,
+            ),
+        ),
+        constraint_specs=(
+            ConstraintSpec(
+                a_scale_tensor_index,
+                0,
+                lambda shapes: pad_up(
+                    shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128
+                ),
+            ),
+            ConstraintSpec(
+                out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
+            ),
+        ),
+    )

-    # Register the module
-    return SimpleNamespace(
-        trtllm_fp4_gemm=trtllm_fp4_gemm,
+    inputs = [
+        workspace_buffer,
+        a,
+        b,
+        a_descale,
+        b_descale,
+        alpha,
+        out,
+    ]
+    runner, tactic = tuner.choose_one(
+        f"fp4_gemm_auto_{'8x4' if use_8x4_sf_layout else '128x4'}",
+        runners,
+        tuning_config,
+        inputs,
     )

+    runner(inputs=inputs, tactic=tactic)
+

 def gemm_fp8_nt_blockscaled(
     a: torch.Tensor,