Commit f601746

refactor: unify autotuner for fp4 backends
Signed-off-by: Vincent Huang <[email protected]>
1 parent ade3885 commit f601746

flashinfer/gemm.py

Lines changed: 111 additions & 121 deletions
@@ -359,70 +359,22 @@ def forward(
             tactic: int = -1,
             do_preparation: bool = False,
         ):
-            a, b, a_descale, b_descale, alpha, out, workspace_buffer = inputs
+            workspace_buffer, a, b, a_descale, b_descale, alpha, out = inputs
             module.fp4_gemm.default(
                 a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic
             )
             return out

     @register_custom_op(
-        "flashinfer::cutlass_fp4_gemm",
+        "flashinfer::cutlass_fp4_gemm_runner",
         mutates_args=(""),
     )
-    def cutlass_fp4_gemm(
-        a: torch.Tensor,
-        b: torch.Tensor,
-        a_descale: torch.Tensor,
-        b_descale: torch.Tensor,
-        alpha: torch.Tensor,
-        out: torch.Tensor,
-        workspace_buffer: torch.Tensor,
-    ):
-        tuner = AutoTuner.get()
-
-        a_tensor_index = 0
-        a_scale_tensor_index = 2
-        out_tensor_index = 5
-
-        def pad_up(x, y):
-            return ((x + y - 1) // y) * y
-
-        tuning_config = TuningConfig(
-            dynamic_tensor_specs=(
-                DynamicTensorSpec(
-                    a_tensor_index,
-                    0,
-                    get_last_power_of_2_num_tokens_buckets,
-                    last_positive_power_of_2,
-                ),
-            ),
-            constraint_specs=(
-                ConstraintSpec(
-                    a_scale_tensor_index,
-                    0,
-                    lambda shapes: pad_up(shapes[a_tensor_index][0], 128),
-                ),
-                ConstraintSpec(
-                    out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
-                ),
-            ),
-        )
-
-        fp4_runner = CutlassFp4GemmRunner()
-
-        inputs = [a, b, a_descale, b_descale, alpha, out, workspace_buffer]
-        _, tactic = tuner.choose_one(
-            "cutlass_fp4_gemm",
-            [fp4_runner],
-            tuning_config,
-            inputs,
-        )
-
-        fp4_runner(inputs=inputs, tactic=tactic)
+    def cutlass_fp4_gemm_runner():
+        return CutlassFp4GemmRunner()

     # Register the module
     return SimpleNamespace(
-        cutlass_fp4_gemm=cutlass_fp4_gemm,
+        cutlass_fp4_gemm_runner=cutlass_fp4_gemm_runner,
     )


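Note on the `inputs` reordering in `forward` above: the cutlass runner now unpacks its inputs in the order the trtllm runner already uses, so the shared tuning config in the new `fp4_gemm_sm100` (last hunk below) can address tensors purely by position. A sketch of that positional layout, derived from the indices in this commit; the name `INPUT_LAYOUT` is illustrative only and not code from the diff:

    # Positions in the `inputs` list shared by both fp4 runners, matching
    # a_tensor_index=1, a_scale_tensor_index=3, out_tensor_index=6 in fp4_gemm_sm100.
    INPUT_LAYOUT = {
        0: "workspace_buffer",
        1: "a",          # dim 0 (num tokens) is the dynamically tuned dimension
        2: "b",
        3: "a_descale",  # dim 0 constrained to pad_up(a.shape[0], 8 or 128)
        4: "b_descale",
        5: "alpha",
        6: "out",        # dim 0 constrained to a.shape[0]
    }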
@@ -1316,7 +1268,7 @@ def mm_fp4(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn",
 ) -> torch.Tensor:
     r"""MM FP4

@@ -1470,25 +1422,47 @@ def mm_fp4(
                 f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations."
             )

-        get_trtllm_fp4_gemm_module().trtllm_fp4_gemm(
+        fp4_gemm_sm100(
             a,
             b.T,
             a_descale,
             b_descale.T,
             alpha,
             out,
-            use_8x4_sf_layout=use_8x4_sf_layout,
-            workspace_buffer=workspace_buffer,
+            workspace_buffer,
+            use_8x4_sf_layout,
+            ["trtllm"],
         )
     elif backend == "cutlass":
         # cutlass require uint8 scale when a/b is fp4 packed uint8.
         if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
             a_descale = a_descale.view(torch.uint8)
         if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn:
             b_descale = b_descale.view(torch.uint8)
-        get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm(
-            a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer
+        fp4_gemm_sm100(
+            a,
+            b.T,
+            a_descale,
+            b_descale.T,
+            alpha,
+            out,
+            workspace_buffer,
+            use_8x4_sf_layout,
+            ["cutlass"],
         )
+    elif backend == "auto":
+        fp4_gemm_sm100(
+            a,
+            b.T,
+            a_descale,
+            b_descale.T,
+            alpha,
+            out,
+            workspace_buffer,
+            use_8x4_sf_layout,
+            ["trtllm", "cutlass"],
+        )
+
     return out


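The three `fp4_gemm_sm100` call sites above differ only in the runner-name list they pass. A condensed sketch of that mapping (the dict and its name are illustrative, not code from this commit; the cudnn backend keeps its own separate path):

    # mm_fp4 backend argument -> candidate runner names passed to fp4_gemm_sm100.
    # "auto" lets the autotuner benchmark both backends and keep the faster one.
    BACKEND_TO_RUNNER_NAMES = {
        "trtllm": ["trtllm"],
        "cutlass": ["cutlass"],
        "auto": ["trtllm", "cutlass"],
    }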
@@ -1782,76 +1756,92 @@ def forward(
             return out

     @register_custom_op(
-        "flashinfer::trtllm_fp4_gemm",
+        "flashinfer::trtllm_fp4_gemm_runner",
         mutates_args=(""),
     )
-    def trtllm_fp4_gemm(
-        a: torch.Tensor,
-        b: torch.Tensor,
-        a_descale: torch.Tensor,
-        b_descale: torch.Tensor,
-        alpha: torch.Tensor,
-        out: torch.Tensor,
-        use_8x4_sf_layout: bool,
-        workspace_buffer: torch.Tensor,
-    ):
-        tuner = AutoTuner.get()
-
-        a_tensor_index = 1
-        a_scale_tensor_index = 3
-        out_tensor_index = 6
-
-        def pad_up(x, y):
-            return ((x + y - 1) // y) * y
-
-        tuning_config = TuningConfig(
-            dynamic_tensor_specs=(
-                DynamicTensorSpec(
-                    a_tensor_index,
-                    0,
-                    get_last_power_of_2_num_tokens_buckets,
-                    last_positive_power_of_2,
-                ),
-            ),
-            constraint_specs=(
-                ConstraintSpec(
-                    a_scale_tensor_index,
-                    0,
-                    lambda shapes: pad_up(
-                        shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128
-                    ),
-                ),
-                ConstraintSpec(
-                    out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
-                ),
-            ),
-        )
+    def trtllm_fp4_gemm_runner(use_8x4_sf_layout):
+        return TrtllmFp4GemmRunner(use_8x4_sf_layout)

-        fp4_runner = TrtllmFp4GemmRunner(use_8x4_sf_layout)
+    # Register the module
+    return SimpleNamespace(
+        trtllm_fp4_gemm_runner=trtllm_fp4_gemm_runner,
+    )

-        inputs = [
-            workspace_buffer,
-            a,
-            b,
-            a_descale,
-            b_descale,
-            alpha,
-            out,
-        ]
-        _, tactic = tuner.choose_one(
-            "trtllm_fp4_gemm_8x4" if use_8x4_sf_layout else "trtllm_fp4_gemm_128x4",
-            [fp4_runner],
-            tuning_config,
-            inputs,
-        )

-        fp4_runner(inputs=inputs, tactic=tactic)
+def fp4_gemm_sm100(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: torch.Tensor,
+    out: torch.Tensor,
+    workspace_buffer: torch.Tensor,
+    use_8x4_sf_layout: bool,
+    runner_names: List[str],
+):
+    runners = []
+
+    if "trtllm" in runner_names:
+        runners.append(
+            get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner(
+                use_8x4_sf_layout=use_8x4_sf_layout
+            )
+        )
+    if "cutlass" in runner_names and not use_8x4_sf_layout:
+        runners.append(get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner())
+    if len(runners) == 0:
+        raise ValueError("No runner specified")
+
+    tuner = AutoTuner.get()
+
+    a_tensor_index = 1
+    a_scale_tensor_index = 3
+    out_tensor_index = 6
+
+    def pad_up(x, y):
+        return ((x + y - 1) // y) * y
+
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(
+            DynamicTensorSpec(
+                a_tensor_index,
+                0,
+                get_last_power_of_2_num_tokens_buckets,
+                last_positive_power_of_2,
+            ),
+        ),
+        constraint_specs=(
+            ConstraintSpec(
+                a_scale_tensor_index,
+                0,
+                lambda shapes: pad_up(
+                    shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128
+                ),
+            ),
+            ConstraintSpec(
+                out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
+            ),
+        ),
+    )

-    # Register the module
-    return SimpleNamespace(
-        trtllm_fp4_gemm=trtllm_fp4_gemm,
+    inputs = [
+        workspace_buffer,
+        a,
+        b,
+        a_descale,
+        b_descale,
+        alpha,
+        out,
+    ]
+    runner, tactic = tuner.choose_one(
+        "fp4_gemm_auto_128x4",
+        runners,
+        tuning_config,
+        inputs,
     )

+    runner(inputs=inputs, tactic=tactic)
+

 def gemm_fp8_nt_blockscaled(
     a: torch.Tensor,

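For context on the constraint specs in the new `fp4_gemm_sm100`: `pad_up` rounds the scale-factor tensor's token dimension up to the layout's row granularity (8 for the 8x4 scale-factor layout, 128 for the 128x4 layout), while the output's dim 0 is pinned to `a`'s. A small standalone check of the helper exactly as it appears in the diff:

    def pad_up(x, y):
        # Round x up to the next multiple of y.
        return ((x + y - 1) // y) * y

    # Scale-factor rows are padded to the layout granularity used during tuning:
    assert pad_up(1000, 128) == 1024  # 128x4 layout
    assert pad_up(1000, 8) == 1000    # 8x4 layout; 1000 is already a multiple of 8
    assert pad_up(7, 8) == 8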