@@ -4011,3 +4011,284 @@ def _(
40114011 assert output .dtype == torch .bfloat16 , "CuTe DSL bf16 bmm output dtype must be bf16"
40124012 assert output .shape == (batch_size , m ,
40134013 n ), "CuTe DSL bf16 bmm output shape is incorrect"
4014+
4015+ # ======================================================================
4016+ # BF16 Dense Persistent GEMM (CuTe DSL) for Blackwell - Linear layers
4017+ # ======================================================================
4018+
class CuteDSLBf16BlackwellGemmRunner(TunableRunner):
    """
    CuTe DSL BF16 GEMM runner for Linear layers.

    Unlike BMM which operates on [B, M, K] @ [B, N, K] -> [B, M, N],
    GEMM operates on [M, K] @ [N, K]^T -> [M, N] (standard Linear).

    We reuse PersistentDenseGemmKernel with batch_size=1.
    """
    kernel_class = PersistentDenseGemmKernel
    # Compiled-kernel cache shared by all runner instances, keyed by
    # (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, use_tvm_ffi).
    # Shapes are NOT part of the key: the compiled wrapper receives
    # m/n/k/batch_size as runtime arguments at launch.
    kernel_cache = dict()

    # Tune over bucketed token counts on dim 0 of input 0 (the M dim).
    tuning_config = TuningConfig(
        dynamic_tensor_specs=(DynamicTensorSpec(
            0, 0, get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2), ),
    )

    def __init__(self, use_tvm_ffi: bool = True):
        """
        Args:
            use_tvm_ffi: When True, compile with TVM-FFI enabled and launch
                by passing raw data pointers; otherwise build CuTe pointers
                and pass the current CUDA stream explicitly.
        """
        super().__init__()
        self.use_tvm_ffi = use_tvm_ffi

    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
        profile: OptimizationProfile,
        **kwargs,
    ) -> List[tuple]:
        """
        Enumerate (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
        combinations the kernel reports as implementable for this problem.

        Returns an empty list on non-SM100 hardware so the tuner skips
        this runner entirely.
        """
        if not is_sm_100f():
            logger.debug(
                f"CuteDSL: SM version {get_sm_version()} is not supported. "
                f"CuteDSL BF16 GEMM only supports SM 100 family. Skipping all tactics."
            )
            return []

        # input: [M, K], weight: [N, K]
        m, k = inputs[0].shape[0], inputs[0].shape[1]
        n = inputs[1].shape[0]
        batch_size = 1

        # Layouts: A is [M, K] K-major, B is [N, K] K-major
        a_major = "k"
        b_major = "k"
        c_major = "n"

        use_2cta_instrs_candi = [False, True]
        mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
        cluster_shape_mn_candi = [
            (1, 1),
            (1, 2),
            (1, 4),
            (2, 1),
            (2, 2),
            (2, 4),
            (4, 1),
            (4, 2),
            (4, 4),
        ]
        # Filter the Cartesian product through the kernel's own
        # feasibility check; invalid tile/cluster combos are dropped here
        # rather than failing at compile time.
        return [
            (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
            for use_2cta_instrs in use_2cta_instrs_candi
            for mma_tiler_mn in mma_tiler_mn_candi
            for cluster_shape_mn in cluster_shape_mn_candi
            if self.__class__.kernel_class.can_implement(
                cutlass.BFloat16,  # ab_dtype
                cutlass.Float32,  # acc_dtype
                cutlass.BFloat16,  # c_dtype
                use_2cta_instrs,
                mma_tiler_mn,
                cluster_shape_mn,
                m,
                n,
                k,
                batch_size,
                a_major,
                b_major,
                c_major,
            )
        ]

    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic,
    ) -> None:
        """
        Performs bf16 dense persistent GEMM using CuTe DSL.

        Args:
            inputs (List[torch.Tensor]):
                inputs[0]: Input tensor of shape (m, k), dtype: bf16.
                inputs[1]: Weight tensor of shape (n, k), dtype: bf16.
                inputs[2]: Output tensor of shape (m, n), dtype: bf16.
                    Written in place.
            tactic: Tiling and cluster strategy, typically a tuple
                (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
        """
        if isinstance(tactic, tuple):
            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
        else:
            # Fallback default when the tuner hands us a non-tuple
            # sentinel tactic: single-CTA 128x128 tile, 1x1 cluster.
            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
                False,
                (128, 128),
                (1, 1),
            ]

        a_tensor, b_tensor, c_tensor = inputs

        # Input: [M, K], Weight: [N, K], Output: [M, N]
        m, k = a_tensor.shape[0], a_tensor.shape[1]
        n = b_tensor.shape[0]
        batch_size = 1

        # Ensure inputs are contiguous
        a_tensor = a_tensor.contiguous()
        b_tensor = b_tensor.contiguous()

        # For output, use a contiguous scratch buffer if needed; results
        # are copied back into c_tensor at the end in that case.
        c_needs_copy = not c_tensor.is_contiguous()
        if c_needs_copy:
            c_buf = torch.empty_like(c_tensor)
        else:
            c_buf = c_tensor

        # Reshape to [1, M, K], [1, N, K] for the batched kernel.
        a_batched = a_tensor.unsqueeze(0)  # [1, M, K]
        b_batched = b_tensor.unsqueeze(0)  # [1, N, K]
        # c_buf is [M, N]; append a trailing batch dim -> [M, N, 1].
        # (The batch dim goes last here, matching the cute layout the
        # wrapper expects — note this differs from A/B, which carry the
        # batch dim first.)
        c_tmp = c_buf.unsqueeze(-1)  # [M, N, 1]

        if not self.use_tvm_ffi:
            # Non-FFI path: build device pointers and capture the current
            # stream up front — needed both for compilation (cache miss)
            # and for every launch.
            a_ptr = make_ptr(
                cutlass.BFloat16,
                a_batched.data_ptr(),
                cute.AddressSpace.gmem,
                assumed_align=16,
            )
            b_ptr = make_ptr(
                cutlass.BFloat16,
                b_batched.data_ptr(),
                cute.AddressSpace.gmem,
                assumed_align=16,
            )
            c_cute_tensor = cute.runtime.from_dlpack(
                c_tmp).mark_layout_dynamic(leading_dim=1)

            stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

        cache_key = (
            use_2cta_instrs,
            mma_tiler_mn,
            cluster_shape_mn,
            self.use_tvm_ffi,
        )
        if cache_key not in self.__class__.kernel_cache:
            if self.use_tvm_ffi:
                # FFI path only needs pointers/tensor/stream for the
                # one-time compilation; launches pass raw data_ptr()s.
                a_ptr = make_ptr(
                    cutlass.BFloat16,
                    a_batched.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                b_ptr = make_ptr(
                    cutlass.BFloat16,
                    b_batched.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                c_cute_tensor = cute.runtime.from_dlpack(
                    c_tmp).mark_layout_dynamic(leading_dim=1)
                stream = cute.runtime.make_fake_stream(
                    use_tvm_ffi_env_stream=True)

            gemm = self.__class__.kernel_class(
                cutlass.Float32,  # acc_dtype
                use_2cta_instrs=use_2cta_instrs,
                mma_tiler_mn=mma_tiler_mn,
                cluster_shape_mn=cluster_shape_mn,
            )
            hardware_info = cutlass.utils.HardwareInfo()
            max_active_clusters = hardware_info.get_max_active_clusters(
                cluster_shape_mn[0] * cluster_shape_mn[1])

            compiled_gemm = cute.compile(
                gemm.wrapper,
                m,
                n,
                k,
                batch_size,
                a_ptr,
                b_ptr,
                c_cute_tensor,
                max_active_clusters=max_active_clusters,
                stream=stream,
                options="--opt-level 2 --enable-tvm-ffi"
                if self.use_tvm_ffi else "--opt-level 2",
            )
            self.__class__.kernel_cache[cache_key] = compiled_gemm
        else:
            compiled_gemm = self.__class__.kernel_cache[cache_key]

        # launch gemm kernel
        if self.use_tvm_ffi:
            compiled_gemm(
                m,
                n,
                k,
                batch_size,
                a_batched.data_ptr(),
                b_batched.data_ptr(),
                c_tmp,
            )
        else:
            compiled_gemm(
                m,
                n,
                k,
                batch_size,
                a_ptr,
                b_ptr,
                c_cute_tensor,
                stream=stream,
            )

        # Copy result back if original output was non-contiguous
        if c_needs_copy:
            c_tensor.copy_(c_buf)
4246+
# input: [M, K], weight: [N, K], output: [M, N]
@torch.library.custom_op("trtllm::cute_dsl_bf16_gemm_blackwell",
                         mutates_args=("output", ),
                         device_types="cuda")
def cute_dsl_bf16_gemm_blackwell(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    use_tvm_ffi: bool = True,
) -> None:
    """
    CuTe DSL BF16 GEMM for Linear layers on Blackwell.

    Computes: output = input @ weight^T
    - input: [M, K] (num_tokens, in_features)
    - weight: [N, K] (out_features, in_features)
    - output: [M, N] (num_tokens, out_features)
    """
    # Blackwell-only kernel: refuse to run on any other SM family.
    if not is_sm_100f():
        raise ValueError(
            f"CuteDSL: SM version {get_sm_version()} is not supported. "
            f"CuteDSL BF16 GEMM only supports SM 100 family.")

    gemm_runner = CuteDSLBf16BlackwellGemmRunner(use_tvm_ffi=use_tvm_ffi)
    op_tensors = [input, weight, output]

    # Let the autotuner pick the best tactic for this problem shape,
    # then run the GEMM with it (output is mutated in place).
    _, best_tactic = AutoTuner.get().choose_one(
        "trtllm::cute_dsl_bf16_gemm_blackwell::gemm",
        [gemm_runner],
        CuteDSLBf16BlackwellGemmRunner.tuning_config,
        op_tensors,
    )
    gemm_runner(op_tensors, tactic=best_tactic)
4283+
@torch.library.register_fake("trtllm::cute_dsl_bf16_gemm_blackwell")
def _(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: torch.Tensor,
    use_tvm_ffi: bool = True,
) -> None:
    """
    Fake (meta) implementation for shape/dtype validation during tracing.

    Mirrors the real op's contract: mat_a is [M, K], mat_b is [N, K],
    output must be a preallocated bf16 tensor of shape [M, N]. Nothing
    is computed; the op mutates ``output`` in the real implementation.
    """
    m = mat_a.shape[0]
    n = mat_b.shape[0]
    assert output.dtype == torch.bfloat16, "CuTe DSL bf16 gemm output dtype must be bf16"
    assert output.shape == (m, n), "CuTe DSL bf16 gemm output shape is incorrect"
0 commit comments