diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py
index 7f799fa7f..457173eef 100755
--- a/flashinfer/gemm.py
+++ b/flashinfer/gemm.py
@@ -468,88 +468,37 @@ def cutlass_fp8_gemm(
 def get_gemm_sm100_module_cutlass_fp4():
     module = gen_gemm_sm100_module_cutlass_fp4().build_and_load()
 
-    class CutlassFp4GemmRunner(TunableRunner):
-        def __init__(self):
-            self._fp4_gemm_runner = module.fp4_gemm
-
-        def get_valid_tactics(
-            self,
-            inputs: List[torch.Tensor],
-            profile: OptimizationProfile,
-        ) -> List[int]:
-            return list(range(module.fp4_gemm_tactic_num()))
-
-        def forward(
-            self,
-            inputs: List[torch.Tensor],
-            *,
-            tactic: int = -1,
-            do_preparation: bool = False,
-        ):
-            a, b, a_descale, b_descale, alpha, out, workspace_buffer = inputs
-            module.fp4_gemm.default(
-                a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic
-            )
-            return out
-
     @register_custom_op(
-        "flashinfer::cutlass_fp4_gemm",
+        "flashinfer::cutlass_fp4_gemm_runner",
         mutates_args=(""),
     )
-    def cutlass_fp4_gemm(
-        a: torch.Tensor,
-        b: torch.Tensor,
-        a_descale: torch.Tensor,
-        b_descale: torch.Tensor,
-        alpha: torch.Tensor,
-        out: torch.Tensor,
-        workspace_buffer: torch.Tensor,
-    ):
-        tuner = AutoTuner.get()
-
-        a_tensor_index = 0
-        a_scale_tensor_index = 2
-        out_tensor_index = 5
-
-        def pad_up(x, y):
-            return ((x + y - 1) // y) * y
-
-        tuning_config = TuningConfig(
-            dynamic_tensor_specs=(
-                DynamicTensorSpec(
-                    a_tensor_index,
-                    0,
-                    get_last_power_of_2_num_tokens_buckets,
-                    last_positive_power_of_2,
-                ),
-            ),
-            constraint_specs=(
-                ConstraintSpec(
-                    a_scale_tensor_index,
-                    0,
-                    lambda shapes: pad_up(shapes[a_tensor_index][0], 128),
-                ),
-                ConstraintSpec(
-                    out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
-                ),
-            ),
-        )
-
-        fp4_runner = CutlassFp4GemmRunner()
-
-        inputs = [a, b, a_descale, b_descale, alpha, out, workspace_buffer]
-        _, tactic = tuner.choose_one(
-            "cutlass_fp4_gemm",
-            [fp4_runner],
-            tuning_config,
-            inputs,
-        )
+    def cutlass_fp4_gemm_runner():
+        class CutlassFp4GemmRunner(TunableRunner):
+            def get_valid_tactics(
+                self,
+                inputs: List[torch.Tensor],
+                profile: OptimizationProfile,
+            ) -> List[int]:
+                return list(range(module.fp4_gemm_tactic_num()))
+
+            def forward(
+                self,
+                inputs: List[torch.Tensor],
+                *,
+                tactic: int = -1,
+                do_preparation: bool = False,
+            ):
+                workspace_buffer, a, b, a_descale, b_descale, alpha, out = inputs
+                module.fp4_gemm.default(
+                    a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic
+                )
+                return out
 
-        fp4_runner(inputs=inputs, tactic=tactic)
+        return CutlassFp4GemmRunner()
 
     # Register the module
     return SimpleNamespace(
-        cutlass_fp4_gemm=cutlass_fp4_gemm,
+        cutlass_fp4_gemm_runner=cutlass_fp4_gemm_runner,
     )
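
Note on the hunk above: `get_gemm_sm100_module_cutlass_fp4()` no longer runs the autotuner itself; it now only exposes the `cutlass_fp4_gemm_runner()` factory, and the runner's `forward` unpacks its inputs as `workspace_buffer, a, b, a_descale, b_descale, alpha, out` (workspace first), so the same input list can be shared with the TRTLLM runner. Below is a hedged sketch of driving the factory directly, using the `runner(inputs=..., tactic=...)` call pattern this patch itself uses; the shapes, the 32 MB workspace, and the all-ones scale factors are illustrative assumptions, and an SM100 GPU plus a built FlashInfer are required:

```python
import torch

from flashinfer.gemm import get_gemm_sm100_module_cutlass_fp4

# Assumed fp4 packing: two e2m1 values per uint8, one fp8 scale per 16 elements.
m, n, k = 128, 256, 512
a = torch.randint(0, 256, (m, k // 2), dtype=torch.uint8, device="cuda")
b = torch.randint(0, 256, (n, k // 2), dtype=torch.uint8, device="cuda")
a_descale = torch.ones(m, k // 16, device="cuda").to(torch.float8_e4m3fn).view(torch.uint8)
b_descale = torch.ones(n, k // 16, device="cuda").to(torch.float8_e4m3fn).view(torch.uint8)
alpha = torch.tensor(1.0, device="cuda")
out = torch.empty(m, n, dtype=torch.bfloat16, device="cuda")
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")

runner = get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner()
# tactic=-1 is the runner's default (untuned) kernel configuration.
runner(inputs=[workspace_buffer, a, b, a_descale, b_descale, alpha, out], tactic=-1)
```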
@@ -1618,15 +1567,16 @@ def mm_fp4(
                 f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations."
             )
 
-        get_trtllm_fp4_gemm_module().trtllm_fp4_gemm(
+        fp4_gemm_sm100(
             a,
             b.T,
             a_descale,
             b_descale.T,
             alpha,
             out,
-            use_8x4_sf_layout=use_8x4_sf_layout,
-            workspace_buffer=workspace_buffer,
+            workspace_buffer,
+            use_8x4_sf_layout,
+            ["trtllm"],
         )
     elif backend == "cutlass":
         # cutlass require uint8 scale when a/b is fp4 packed uint8.
@@ -1634,9 +1584,18 @@ def mm_fp4(
             a_descale = a_descale.view(torch.uint8)
         if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn:
             b_descale = b_descale.view(torch.uint8)
-        get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm(
-            a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer
+        fp4_gemm_sm100(
+            a,
+            b.T,
+            a_descale,
+            b_descale.T,
+            alpha,
+            out,
+            workspace_buffer,
+            use_8x4_sf_layout,
+            ["cutlass"],
         )
+
     return out
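
Both `mm_fp4` branches above now funnel into the shared `fp4_gemm_sm100` entry point, differing only in the positional tail (`workspace_buffer, use_8x4_sf_layout, runner_names`). The `.view(torch.uint8)` calls retained in the cutlass branch reinterpret the float8 scale buffers in place rather than copying them; a standalone check of that behavior (pure PyTorch, no FlashInfer needed, assuming a build with `float8_e4m3fn` support):

```python
import torch

# float8_e4m3fn and uint8 are both one byte wide, so .view() reinterprets the
# same storage under a new dtype: no copy, identical data pointer.
sf = torch.randn(4, 4).to(torch.float8_e4m3fn)
sf_u8 = sf.view(torch.uint8)
assert sf_u8.data_ptr() == sf.data_ptr()
print(sf_u8.dtype, sf_u8.shape)  # torch.uint8 torch.Size([4, 4])
```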
@@ -1872,141 +1831,156 @@ def get_trtllm_fp4_gemm_module():
     op = mod.build_and_load()
     setup_cubin_loader(mod.get_library_path())
 
-    class TrtllmFp4GemmRunner(TunableRunner):
-        def __init__(self, use_8x4_sf_layout: bool = True):
-            self._fp4_gemm_runner = op.trtllm_gemm
-            self._use_8x4_sf_layout = use_8x4_sf_layout
+    @register_custom_op(
+        "flashinfer::trtllm_fp4_gemm_runner",
+        mutates_args=(""),
+    )
+    def trtllm_fp4_gemm_runner(use_8x4_sf_layout):
+        class TrtllmFp4GemmRunner(TunableRunner):
+            def __init__(self, use_8x4_sf_layout: bool = True):
+                self._use_8x4_sf_layout = use_8x4_sf_layout
+
+            def get_valid_tactics(
+                self,
+                inputs: List[torch.Tensor],
+                profile: OptimizationProfile,
+            ) -> List[int]:
+                a_tensor_index = 1
+                b_tensor_index = 2
+
+                a = profile.get_opt_shapes()[a_tensor_index]
+                b = profile.get_opt_shapes()[b_tensor_index]
+                m = a[0]
+                n = b[0]
+                k = a[1] * 2
+                (
+                    workspace_buffer,
+                    a,
+                    b,
+                    a_descale,
+                    b_descale,
+                    alpha,
+                    out,
+                ) = inputs
+                type_e2m1 = 0
+                type_bf16 = 2
+                return list(
+                    op.trtllm_gemm_tactics(
+                        m, n, k, type_e2m1, type_bf16, self._use_8x4_sf_layout
+                    )
+                )
 
-        def get_valid_tactics(
-            self,
-            inputs: List[torch.Tensor],
-            profile: OptimizationProfile,
-        ) -> List[int]:
-            a_tensor_index = 1
-            b_tensor_index = 2
-
-            a = profile.get_opt_shapes()[a_tensor_index]
-            b = profile.get_opt_shapes()[b_tensor_index]
-            m = a[0]
-            n = b[0]
-            k = a[1] * 2
-            (
-                workspace_buffer,
-                a,
-                b,
-                a_descale,
-                b_descale,
-                alpha,
-                out,
-            ) = inputs
-            type_e2m1 = 0
-            type_bf16 = 2
-            return list(
-                op.trtllm_gemm_tactics(
-                    m, n, k, type_e2m1, type_bf16, self._use_8x4_sf_layout
+            def forward(
+                self,
+                inputs: List[torch.Tensor],
+                *,
+                tactic: int = -1,
+                do_preparation: bool = False,
+            ):
+                (
+                    workspace_buffer,
+                    a,
+                    b,
+                    a_descale,
+                    b_descale,
+                    alpha,
+                    out,
+                ) = inputs
+                op.trtllm_gemm.default(
+                    workspace_buffer,
+                    a,
+                    b,
+                    a_descale,
+                    b_descale,
+                    alpha,
+                    out,
+                    self._use_8x4_sf_layout,
+                    tactic,
                 )
-            )
+                return out
 
-        def forward(
-            self,
-            inputs: List[torch.Tensor],
-            *,
-            tactic: int = -1,
-            do_preparation: bool = False,
-        ):
-            (
-                workspace_buffer,
-                a,
-                b,
-                a_descale,
-                b_descale,
-                alpha,
-                out,
-            ) = inputs
-            op.trtllm_gemm.default(
-                workspace_buffer,
-                a,
-                b,
-                a_descale,
-                b_descale,
-                alpha,
-                out,
-                self._use_8x4_sf_layout,
-                tactic,
-            )
-            return out
+        return TrtllmFp4GemmRunner(use_8x4_sf_layout)
 
-    @register_custom_op(
-        "flashinfer::trtllm_fp4_gemm",
-        mutates_args=(""),
+    # Register the module
+    return SimpleNamespace(
+        trtllm_fp4_gemm_runner=trtllm_fp4_gemm_runner,
     )
-    def trtllm_fp4_gemm(
-        a: torch.Tensor,
-        b: torch.Tensor,
-        a_descale: torch.Tensor,
-        b_descale: torch.Tensor,
-        alpha: torch.Tensor,
-        out: torch.Tensor,
-        use_8x4_sf_layout: bool,
-        workspace_buffer: torch.Tensor,
-    ):
-        tuner = AutoTuner.get()
-        a_tensor_index = 1
-        a_scale_tensor_index = 3
-        out_tensor_index = 6
 
-        def pad_up(x, y):
-            return ((x + y - 1) // y) * y
+
+def fp4_gemm_sm100(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: torch.Tensor,
+    out: torch.Tensor,
+    workspace_buffer: torch.Tensor,
+    use_8x4_sf_layout: bool,
+    runner_names: List[str],
+):
+    runners = []
 
-        tuning_config = TuningConfig(
-            dynamic_tensor_specs=(
-                DynamicTensorSpec(
-                    a_tensor_index,
-                    0,
-                    get_last_power_of_2_num_tokens_buckets,
-                    last_positive_power_of_2,
-                ),
+    if "trtllm" in runner_names:
+        runners.append(
+            get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner(
+                use_8x4_sf_layout=use_8x4_sf_layout
+            )
+        )
+    if "cutlass" in runner_names and not use_8x4_sf_layout:
+        runners.append(get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner())
+    if len(runners) == 0:
+        raise ValueError("No runner specified")
+
+    tuner = AutoTuner.get()
+
+    a_tensor_index = 1
+    a_scale_tensor_index = 3
+    out_tensor_index = 6
+
+    def pad_up(x, y):
+        return ((x + y - 1) // y) * y
+
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(
+            DynamicTensorSpec(
+                a_tensor_index,
+                0,
+                get_last_power_of_2_num_tokens_buckets,
+                last_positive_power_of_2,
             ),
-            constraint_specs=(
-                ConstraintSpec(
-                    a_scale_tensor_index,
-                    0,
-                    lambda shapes: pad_up(
-                        shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128
-                    ),
-                ),
-                ConstraintSpec(
-                    out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
+        ),
+        constraint_specs=(
+            ConstraintSpec(
+                a_scale_tensor_index,
+                0,
+                lambda shapes: pad_up(
+                    shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128
                 ),
             ),
-        )
-
-        fp4_runner = TrtllmFp4GemmRunner(use_8x4_sf_layout)
-
-        inputs = [
-            workspace_buffer,
-            a,
-            b,
-            a_descale,
-            b_descale,
-            alpha,
-            out,
-        ]
-        _, tactic = tuner.choose_one(
-            "trtllm_fp4_gemm_8x4" if use_8x4_sf_layout else "trtllm_fp4_gemm_128x4",
-            [fp4_runner],
-            tuning_config,
-            inputs,
-        )
-
-        fp4_runner(inputs=inputs, tactic=tactic)
+            ConstraintSpec(
+                out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
+            ),
+        ),
+    )
 
-    # Register the module
-    return SimpleNamespace(
-        trtllm_fp4_gemm=trtllm_fp4_gemm,
+    inputs = [
+        workspace_buffer,
+        a,
+        b,
+        a_descale,
+        b_descale,
+        alpha,
+        out,
+    ]
+    runner, tactic = tuner.choose_one(
+        f"fp4_gemm_auto_{'8x4' if use_8x4_sf_layout else '128x4'}",
+        runners,
+        tuning_config,
+        inputs,
     )
 
+    runner(inputs=inputs, tactic=tactic)
+
 
 def gemm_fp8_nt_blockscaled(
     a: torch.Tensor,
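
With tuning hoisted out of each backend into `fp4_gemm_sm100`, `tuner.choose_one` now receives a heterogeneous runner pool, and its first return value is no longer discarded (`runner, tactic` rather than the old `_, tactic`): a tactic index is only meaningful relative to the runner that reported it, so the winning pair must travel together. A toy model of that contract with simplified stand-ins (these classes are illustrations, not FlashInfer's `TunableRunner`/`AutoTuner` internals):

```python
import time
from typing import Any, List, Tuple


class ToyRunner:
    """Minimal stand-in for a TunableRunner: tactic list plus a callable kernel."""

    def __init__(self, name: str, tactics: List[int]):
        self.name, self.tactics = name, tactics

    def get_valid_tactics(self, inputs: List[Any], profile: Any) -> List[int]:
        return self.tactics

    def __call__(self, inputs: List[Any], tactic: int = -1) -> None:
        time.sleep(0.001 * (tactic + 2))  # stand-in for a kernel launch


def choose_one(runners: List[ToyRunner], inputs: List[Any]) -> Tuple[ToyRunner, int]:
    # Benchmark every (runner, tactic) pair and keep the fastest pair together.
    best, best_tactic, best_dt = runners[0], -1, float("inf")
    for r in runners:
        for t in r.get_valid_tactics(inputs, None):
            t0 = time.perf_counter()
            r(inputs, tactic=t)
            dt = time.perf_counter() - t0
            if dt < best_dt:
                best, best_tactic, best_dt = r, t, dt
    return best, best_tactic


runner, tactic = choose_one([ToyRunner("trtllm", [0, 1]), ToyRunner("cutlass", [0])], [])
runner([], tactic=tactic)  # mirrors the patch's: runner(inputs=inputs, tactic=tactic)
```

Also visible in the hunk above: the cutlass runner only joins the pool when `use_8x4_sf_layout` is false, so for the 8x4 scale-factor layout the TRTLLM runner competes alone.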