Description
Which component has the problem?
CuTe DSL
Bug Report
CUDA Version: 12.9
CuTe DSL Version: 4.3.0.dev0
In hopper/dense_gemm_persistent.py, I tried to add another warp group, but compilation failed with the following error:
error: unknown: NVPTX compiler invocation failed, error log: ptxas fatal : (C7602) Insufficient registers (128) to compile instruction at line 1582 in function kernel_cutlass_kernel___main__HopperWgmmaGemmPersistentKernel_object_at__CopyAtom_ThrID10_TVLayoutSrc1819201_TVLayoutDst1819201_Valuetypef16_tensor000odiv16111012_CopyAtom_ThrID10_TVLayou_0. Try to compile with register target of 154 or higher.
The error suggests that each thread in the MMA warp groups is limited to 128 registers, even though I used setmaxnreg to set the per-thread register count for each warp group as follows:
if warp_group_idx == 0:
    cute.arch.warpgroup_reg_dealloc(24)
if warp_group_idx == 1:
    cute.arch.warpgroup_reg_dealloc(24)
if warp_group_idx in (2, 3):  # Two warp groups: 2 & 3
    cute.arch.warpgroup_reg_alloc(224)

This uses 2 * 128 * 24 + 2 * 128 * 224 = 63488 registers in total, which fits within the 65536 available. It seems that CuTe DSL compiles against the average number of registers per thread instead of taking the setmaxnreg hints into account.
It doesn't seem to be a coincidence that the error reports Insufficient registers (128), which is exactly 65536 // (128 * 4) == 128. It also says Try to compile with register target of 154 or higher, a target that is met when only 3 warp groups are used, since 65536 / (128 * 3) ≈ 170.67 > 154. That's why I think CuTe DSL compiles against the average number of registers per thread and discards the setmaxnreg hints.
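The arithmetic behind this guess can be checked with a short standalone script. This is only a sketch of the numbers quoted above, not of what the compiler actually does. The helper names (`total_regs`, `avg_regs_per_thread`) are made up for illustration, and the script assumes 65536 registers in total and 128 threads per warp group.

```python
# Sanity check of the register arithmetic in this report.
# Assumptions: 65536 registers in total, 128 threads per warp group.
# The helper names below are hypothetical and not part of CuTe DSL.
TOTAL_REGS = 65536
THREADS_PER_WARP_GROUP = 128


def total_regs(regs_per_thread_by_wg):
    """Total registers requested, given a per-thread count for each warp group."""
    return sum(r * THREADS_PER_WARP_GROUP for r in regs_per_thread_by_wg)


def avg_regs_per_thread(num_warp_groups):
    """Per-thread share if registers were split evenly across all threads."""
    return TOTAL_REGS / (THREADS_PER_WARP_GROUP * num_warp_groups)


# setmaxnreg configuration from this report: WG0/WG1 -> 24, WG2/WG3 -> 224
requested = [24, 24, 224, 224]

print("requested total:", total_regs(requested))        # fits if <= TOTAL_REGS
print("average with 4 WGs:", avg_regs_per_thread(4))    # limit named in the error
print("average with 3 WGs:", avg_regs_per_thread(3))    # compare against 154
```

With four warp groups the even split matches the 128 quoted in the error, while with three it exceeds the suggested target of 154. This is consistent with the guess above, though it does not prove it.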
A minimal reproduction can be run like this:
python3 dense_gemm_persistent_four_wgs.py --mnkl 8192,4096,4096,1 --tile_shape_mn 128,256 --cluster_shape_mn 2,1 --a_dtype Float16 --b_dtype Float16 --c_dtype Float16 --acc_dtype Float32 --a_major k --b_major n --c_major n --warmup_iterations 0 --iterations 1 2>&1 | tee dense_gemm_4_wgs.log