Skip to content

Commit 3f5bafb

Browse files
authored
[Cutlass profiler] Fix SM100 FP8 nosmem epilogue shape_div “Divisibility Condition” for non‑multiple‑of‑64 N tiles (#2946)
* . * . * . * . * . * . * .
1 parent 1e6da09 commit 3f5bafb

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

python/cutlass_library/generator.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7488,6 +7488,17 @@ def GenerateSM100_TensorOp_fp8_UMMA_alignx_gemm(manifest, cuda_version, gemm_kin
74887488
TileSchedulerType.Default
74897489
]
74907490

7491+
# Some SM100 NoSmem epilogue instantiations rely on CUTE's shape_div, which enforces a compile-time
7492+
# divisibility condition between CTA N and the epilogue N tile. Keep this conservative and scoped:
7493+
# only apply the divisibility filter for selected common (c_type, d_type) pairs.
7494+
#
7495+
# Map (c_type, d_type) -> required divisor for CTA N when CTA N > divisor.
7496+
# (If CTA N <= divisor, the epilogue N tile equals CTA N and is always divisible.)
7497+
_sm100_epilogue_tile_n_divisibility = {
7498+
(DataType.void, DataType.f16): 64,
7499+
(DataType.void, DataType.bf16): 64,
7500+
}
7501+
74917502
# 1xSM MMA kernels
74927503
for math_inst in math_instructions_1sm:
74937504
tile_descriptions = []
@@ -7607,7 +7618,24 @@ def GenerateSM100_TensorOp_fp8_UMMA_alignx_gemm(manifest, cuda_version, gemm_kin
76077618

76087619
kernel_schedule = KernelScheduleType.WarpSpecialized1SmSm100
76097620
epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized1Sm
7610-
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
7621+
7622+
# SM100 NoSmem epilogue uses EpilogueTileAuto with N-tile = min(64, cta_n).
7623+
# CUTE's shape_div then requires a compile-time divisibility condition between cta_n and 64.
7624+
# Only instantiate kernels where cta_n <= 64 or cta_n is an exact multiple of 64 to avoid
7625+
# violating that "Divisibility Condition" static_assert.
7626+
filtered_tile_descriptions = []
7627+
for tile_description in tile_descriptions:
7628+
div_n = _sm100_epilogue_tile_n_divisibility.get((data_type["c_type"], data_type["d_type"]))
7629+
if div_n is not None:
7630+
cta_n = tile_description.threadblock_shape[1]
7631+
if cta_n > div_n and (cta_n % div_n != 0):
7632+
continue
7633+
filtered_tile_descriptions.append(tile_description)
7634+
7635+
if not filtered_tile_descriptions:
7636+
continue
7637+
7638+
CreateGemmUniversal3xOperator(manifest, layouts, filtered_tile_descriptions, data_type,
76117639
[[kernel_schedule, epi_schedule]],
76127640
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
76137641

0 commit comments

Comments
 (0)