@@ -320,6 +320,8 @@ def get_dense_gemm_approximate_cta_nums(
320320 Sm100BlockScaledPersistentDenseGemmKernel
321321 from ..cute_dsl_kernels .blackwell .top_k .filtered_top_k_decode_varlen import \
322322 FilteredTopKKernelVarlenDecode
323+ from ..cute_dsl_kernels .blackwell .dense_gemm_persistent import \
324+ PersistentDenseGemmKernel
323325 from ..cute_dsl_kernels .blackwell .utils import make_ptr
324326
325327 class CuteDSLNVFP4BlackwellRunner (TunableRunner ):
@@ -3739,3 +3741,273 @@ def warmup_cute_dsl_topk_kernels(
37393741
37403742 logger .info (f"Warmup: pre-compiled { count } CuTE DSL top-k kernels "
37413743 f"(dtype={ dtype } , top_k={ top_k } , next_n={ next_n } )" )
3744+
3745+ # ======================================================================
3746+ # BF16 Dense Persistent BMM (CuTe DSL) for Blackwell
3747+ # ======================================================================
3748+
class CuteDSLBf16BlackwellBmmRunner(TunableRunner):
    """Auto-tunable runner for the BF16 dense persistent batched GEMM
    (CuTe DSL) on the SM 100 (Blackwell) family.

    Inputs are A of shape (batch, m, k) and B of shape (batch, n, k), both
    bf16 with k-major layout; the bf16 result of shape (batch, m, n) is
    written in place into the provided output tensor.
    """

    kernel_class = PersistentDenseGemmKernel
    # Compiled kernels shared across all runner instances. Keyed only by
    # the tactic plus the tvm-ffi flag: m/n/k/batch are runtime arguments
    # of the compiled wrapper, so shapes do not participate in the key.
    kernel_cache = dict()

    tuning_config = TuningConfig(
        dynamic_tensor_specs=(DynamicTensorSpec(
            0, 1, get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2), ),
    )

    def __init__(self, use_tvm_ffi: bool = True):
        super().__init__()
        # When True, kernels are compiled with the tvm-ffi calling
        # convention and launched with raw data pointers; otherwise CuTe
        # pointer objects are rebuilt per call and an explicit CUDA
        # stream is passed at launch.
        self.use_tvm_ffi = use_tvm_ffi

    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
        profile: OptimizationProfile,
        **kwargs,
    ) -> List[int]:
        """Enumerate the (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
        combinations the kernel reports as implementable for this problem.

        Returns an empty list on non-SM100 devices so the tuner skips
        this runner entirely.
        """
        if not is_sm_100f():
            logger.debug(
                f"CuteDSL: SM version {get_sm_version()} is not supported. "
                f"CuteDSL BF16 BMM only supports SM 100 family. Skipping all tactics."
            )
            return []
        # inputs[0]: [b, m, k]
        batch_size, m, k = inputs[0].shape[0], inputs[0].shape[
            1], inputs[0].shape[2]
        # inputs[1]: [b, n, k]
        n = inputs[1].shape[1]
        # A is (m, k) k-major, B is (n, k) k-major, C is written (m, n)
        # n-major.
        a_major = "k"
        b_major = "k"
        c_major = "n"

        use_2cta_instrs_candi = [False, True]
        mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
        cluster_shape_mn_candi = [
            (1, 1),
            (1, 2),
            (1, 4),
            (2, 1),
            (2, 2),
            (2, 4),
            (4, 1),
            (4, 2),
            (4, 4),
        ]
        # Cross product of candidates, filtered by the kernel's own
        # feasibility check for this dtype/shape combination.
        return [
            (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
            for use_2cta_instrs in use_2cta_instrs_candi
            for mma_tiler_mn in mma_tiler_mn_candi
            for cluster_shape_mn in cluster_shape_mn_candi
            if self.__class__.kernel_class.can_implement(
                cutlass.BFloat16,  # ab_dtype
                cutlass.Float32,  # acc_dtype
                cutlass.BFloat16,  # c_dtype
                use_2cta_instrs,
                mma_tiler_mn,
                cluster_shape_mn,
                m,
                n,
                k,
                batch_size,
                a_major,
                b_major,
                c_major,
            )
        ]

    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic,
    ) -> None:
        """
        Performs bf16 dense persistent batched gemm using CuTe DSL.

        Args:
            inputs (List[torch.Tensor]):
                inputs[0]: Input tensor of shape (batch_size, m, k), dtype: bf16.
                inputs[1]: Weight tensor of shape (batch_size, n, k), dtype: bf16.
                inputs[2]: Output tensor of shape (batch_size, m, n), dtype: bf16.
            tactic: Tiling and cluster strategy, typically a tuple
                (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
        """
        if isinstance(tactic, tuple):
            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
        else:
            # Non-tuple tactic (e.g. the tuner's default sentinel): fall
            # back to a conservative configuration.
            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = (
                False,
                (128, 128),
                (1, 1),
            )

        a_tensor, b_tensor, c_tensor = inputs

        # Ensure A and B are contiguous — the kernel constructs CuTe
        # layouts via make_ordered_layout assuming contiguous [B, M, K]
        # and [B, N, K]. Transpose views (e.g. from .transpose(0,1))
        # have swapped batch/seq strides which would cause the kernel
        # to read from wrong memory locations.
        a_tensor = a_tensor.contiguous()
        b_tensor = b_tensor.contiguous()

        # For the output, use a contiguous buffer so TMA store sees a
        # standard layout; copy back afterwards if the original was
        # non-contiguous.
        c_needs_copy = not c_tensor.is_contiguous()
        if c_needs_copy:
            c_buf = torch.empty_like(c_tensor)
        else:
            c_buf = c_tensor

        # c_buf is [B, M, N], permute to [M, N, B] for cute layout
        c_tmp = c_buf.permute(1, 2, 0)

        batch_size = a_tensor.shape[0]
        m = a_tensor.shape[1]
        k = a_tensor.shape[2]
        n = b_tensor.shape[1]

        if not self.use_tvm_ffi:
            # Direct-launch path: rebuild device pointers / the CuTe
            # output view on every call and launch on the current torch
            # CUDA stream.
            a_ptr = make_ptr(
                cutlass.BFloat16,
                a_tensor.data_ptr(),
                cute.AddressSpace.gmem,
                assumed_align=16,
            )
            b_ptr = make_ptr(
                cutlass.BFloat16,
                b_tensor.data_ptr(),
                cute.AddressSpace.gmem,
                assumed_align=16,
            )
            c_cute_tensor = cute.runtime.from_dlpack(
                c_tmp).mark_layout_dynamic(leading_dim=1)

            stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

        cache_key = (
            use_2cta_instrs,
            mma_tiler_mn,
            cluster_shape_mn,
            self.use_tvm_ffi,
        )
        if cache_key not in self.__class__.kernel_cache:
            if self.use_tvm_ffi:
                # tvm-ffi path needs these objects only at compile time;
                # launches below pass raw data pointers instead.
                a_ptr = make_ptr(
                    cutlass.BFloat16,
                    a_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                b_ptr = make_ptr(
                    cutlass.BFloat16,
                    b_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                c_cute_tensor = cute.runtime.from_dlpack(
                    c_tmp).mark_layout_dynamic(leading_dim=1)
                stream = cute.runtime.make_fake_stream(
                    use_tvm_ffi_env_stream=True)

            gemm = self.__class__.kernel_class(
                cutlass.Float32,  # acc_dtype
                use_2cta_instrs=use_2cta_instrs,
                mma_tiler_mn=mma_tiler_mn,
                cluster_shape_mn=cluster_shape_mn,
            )
            hardware_info = cutlass.utils.HardwareInfo()
            max_active_clusters = hardware_info.get_max_active_clusters(
                cluster_shape_mn[0] * cluster_shape_mn[1])

            compiled_gemm = cute.compile(
                gemm.wrapper,
                m,
                n,
                k,
                batch_size,
                a_ptr,
                b_ptr,
                c_cute_tensor,
                max_active_clusters=max_active_clusters,
                stream=stream,
                # Plain strings — there are no placeholders to format.
                options="--opt-level 2 --enable-tvm-ffi"
                if self.use_tvm_ffi else "--opt-level 2",
            )
            self.__class__.kernel_cache[cache_key] = compiled_gemm
        else:
            compiled_gemm = self.__class__.kernel_cache[cache_key]

        # launch gemm kernel
        if self.use_tvm_ffi:
            compiled_gemm(
                m,
                n,
                k,
                batch_size,
                a_tensor.data_ptr(),
                b_tensor.data_ptr(),
                c_tmp,
            )
        else:
            compiled_gemm(
                m,
                n,
                k,
                batch_size,
                a_ptr,
                b_ptr,
                c_cute_tensor,
                stream=stream,
            )

        # Copy result back if original output was non-contiguous
        if c_needs_copy:
            c_tensor.copy_(c_buf)
3972+
# a/b: bf16, output: bf16
@torch.library.custom_op("trtllm::cute_dsl_bf16_bmm_blackwell",
                         mutates_args=("output", ),
                         device_types="cuda")
def cute_dsl_bf16_bmm_blackwell(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    use_tvm_ffi: bool = True,
) -> None:
    """Run the auto-tuned CuTe DSL bf16 batched GEMM on Blackwell.

    The result is written into ``output`` in place (declared via
    ``mutates_args``). Raises ValueError on non-SM100 devices.
    """
    if not is_sm_100f():
        raise ValueError(
            f"CuteDSL: SM version {get_sm_version()} is not supported. "
            f"CuteDSL BF16 BMM only supports SM 100 family.")

    runner = CuteDSLBf16BlackwellBmmRunner(use_tvm_ffi=use_tvm_ffi)
    op_inputs = [input, weight, output]

    # Let the auto-tuner pick the best tactic for this shape bucket,
    # then execute the runner with it.
    _, best_tactic = AutoTuner.get().choose_one(
        "trtllm::cute_dsl_bf16_bmm_blackwell::gemm",
        [runner],
        CuteDSLBf16BlackwellBmmRunner.tuning_config,
        op_inputs,
    )
    runner(op_inputs, tactic=best_tactic)
4001+
@torch.library.register_fake("trtllm::cute_dsl_bf16_bmm_blackwell")
def _(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    use_tvm_ffi: bool = True,
) -> None:
    """Fake (meta) implementation: validate shapes/dtype, compute nothing.

    Parameter names mirror the real op's schema (input, weight, output,
    use_tvm_ffi) so keyword invocation under tracing/compile resolves
    consistently with the eager op.
    """
    # input: [b, m, k]; weight: [b, n, k]; output must be [b, m, n] bf16.
    batch_size, m, k = input.shape[0], input.shape[1], input.shape[2]
    n = weight.shape[1]
    assert output.dtype == torch.bfloat16, "CuTe DSL bf16 bmm output dtype must be bf16"
    assert output.shape == (batch_size, m,
                            n), "CuTe DSL bf16 bmm output shape is incorrect"
0 commit comments