1 file changed (+6 -1): python/triton_kernels/triton_kernels/matmul_ogs_details

@@ -3,11 +3,12 @@
 from dataclasses import dataclass
 
 import triton
+from triton_kernels import target_info
 from triton_kernels.target_info import get_cdna_version
 from triton_kernels.tensor import FP4
 import torch
 from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel
-from triton_kernels.tensor import bitwidth
+from triton_kernels.tensor import bitwidth, get_layout
 
 
 @dataclass
@@ -297,8 +298,12 @@ def make_default_opt_flags_nvidia(
     n_sms = torch.cuda.get_device_properties(0).multi_processor_count
     tiles_per_sm = grid_size_tma / n_sms
     supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
+    requires_persistent = (get_layout(precision_config.act_scale) is not None or get_layout(precision_config.weight_scale) is not None) and target_info.has_native_mxfp()
     if constraints.get("is_persistent", None) is not None:
         is_persistent = constraints["is_persistent"]
+    elif requires_persistent:
+        assert supports_persistent, "persistent kernel required but not supported"
+        is_persistent = True
     else:
         has_simple_epilogue = precision_config.max_num_imprecise_acc is None
         is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
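
The diff only changes how is_persistent is chosen; the surrounding heuristic is untouched. Below is a minimal, self-contained sketch of that decision for reference. The helper name choose_is_persistent and its plain arguments (act_scale_layout, weight_scale_layout, has_native_mxfp, default_heuristic) are hypothetical stand-ins for values the real code reads from precision_config, target_info, and the persistent-TMA checks.

# Sketch of the persistence decision added in this diff (names are stand-ins,
# not the actual opt_flags API).
def choose_is_persistent(constraints, supports_persistent,
                         act_scale_layout, weight_scale_layout,
                         has_native_mxfp, default_heuristic):
    # Scaled (mx) operands with a non-trivial scale layout must run the
    # persistent kernel on hardware with native mxfp support.
    requires_persistent = (
        (act_scale_layout is not None or weight_scale_layout is not None)
        and has_native_mxfp
    )
    if constraints.get("is_persistent", None) is not None:
        # An explicit constraint always wins.
        return constraints["is_persistent"]
    elif requires_persistent:
        # Fail loudly if the layout demands persistence but the target
        # cannot provide it.
        assert supports_persistent, "persistent kernel required but not supported"
        return True
    else:
        # Otherwise fall back to the existing occupancy/dtype heuristic.
        return default_heuristic()

# Example: mx-scaled operands on hardware with native mxfp support force the
# persistent kernel unless a constraint overrides it.
print(choose_is_persistent({}, True, object(), None, True, lambda: False))  # -> True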