@@ -250,8 +250,120 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
250
250
CUTLASS_CHECK (status);
251
251
}
252
252
253
+ template <typename InType, typename OutType,
254
+ template <typename , typename > typename Epilogue>
255
+ struct sm80_config_default {
256
+ // Bundle of tile shapes exposing the resulting kernel type as Cutlass2xGemm. This config is used in 2 cases,
257
+ // - M in (128, inf)
258
+ // - M in (64, 128] and N >= 8192
259
+ static_assert (std::is_same<InType, int8_t >()); // this config only accepts int8_t inputs
260
+ using TileShape = typename cutlass::gemm::GemmShape<128 , 128 , 64 >;
261
+ using WarpShape = typename cutlass::gemm::GemmShape<64 , 64 , 64 >;
262
+ using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
263
+ using Cutlass2xGemm =
264
+ cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
265
+ Epilogue, TileShape, WarpShape, InstructionShape, 5 >; // NOTE(review): the trailing 5 is presumably the pipeline stage count -- confirm against cutlass_2x_gemm
266
+ };
267
+
268
+ template <typename InType, typename OutType,
269
+ template <typename , typename > typename Epilogue>
270
+ struct sm80_config_M64 {
271
+ // Bundle of tile shapes exposing the resulting kernel type as Cutlass2xGemm. This config is used in 2 cases,
272
+ // - M in (32, 64]
273
+ // - M in (64, 128] and N < 8192
274
+ static_assert (std::is_same<InType, int8_t >()); // this config only accepts int8_t inputs
275
+ using TileShape = typename cutlass::gemm::GemmShape<64 , 128 , 128 >;
276
+ using WarpShape = typename cutlass::gemm::GemmShape<64 , 64 , 64 >;
277
+ using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
278
+ using Cutlass2xGemm =
279
+ cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
280
+ Epilogue, TileShape, WarpShape, InstructionShape, 5 >; // NOTE(review): the trailing 5 is presumably the pipeline stage count -- confirm against cutlass_2x_gemm
281
+ };
282
+
283
+ template <typename InType, typename OutType,
284
+ template <typename , typename > typename Epilogue>
285
+ struct sm80_config_M32 {
286
+ // Bundle of tile shapes exposing the resulting kernel type as Cutlass2xGemm; used for M in (16, 32]
287
+ static_assert (std::is_same<InType, int8_t >()); // this config only accepts int8_t inputs
288
+ using TileShape = typename cutlass::gemm::GemmShape<32 , 64 , 128 >;
289
+ using WarpShape = typename cutlass::gemm::GemmShape<32 , 64 , 64 >;
290
+ using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
291
+ using Cutlass2xGemm =
292
+ cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
293
+ Epilogue, TileShape, WarpShape, InstructionShape, 5 >; // NOTE(review): the trailing 5 is presumably the pipeline stage count -- confirm against cutlass_2x_gemm
294
+ };
295
+
296
+ template <typename InType, typename OutType,
297
+ template <typename , typename > typename Epilogue>
298
+ struct sm80_config_M16 {
299
+ // Bundle of tile shapes exposing the resulting kernel type as Cutlass2xGemm; used for M in [1, 16]
300
+ static_assert (std::is_same<InType, int8_t >()); // this config only accepts int8_t inputs
301
+ using TileShape = typename cutlass::gemm::GemmShape<16 , 64 , 128 >;
302
+ using WarpShape = typename cutlass::gemm::GemmShape<16 , 64 , 64 >;
303
+ using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
304
+ using Cutlass2xGemm =
305
+ cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
306
+ Epilogue, TileShape, WarpShape, InstructionShape, 5 >; // NOTE(review): the trailing 5 is presumably the pipeline stage count -- confirm against cutlass_2x_gemm
307
+ };
308
+
253
309
} // namespace
254
310
311
+ template <typename InType, typename OutType,
312
+ template <typename , typename > typename Epilogue,
313
+ typename ... EpilogueArgs>
314
+ void cutlass_gemm_sm80_dispatch (torch::Tensor& out, torch::Tensor const & a,
315
+ torch::Tensor const & b,
316
+ EpilogueArgs&&... args) { // Picks one sm80_config_* kernel type from a.size(0) (and out.size(1) for the (64, 128] bucket) and forwards out, a, b and args to cutlass_gemm_caller.
317
+ static_assert (std::is_same<InType, int8_t >()); // compile-time guard: only int8_t inputs are dispatched here
318
+ TORCH_CHECK (a.dtype () == torch::kInt8 ); // runtime counterpart of the InType guard for operand a
319
+ TORCH_CHECK (b.dtype () == torch::kInt8 ); // runtime counterpart of the InType guard for operand b
320
+
321
+ using Cutlass2xGemmDefault =
322
+ typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
323
+ using Cutlass2xGemmM128BigN =
324
+ typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm; // same kernel type as Cutlass2xGemmDefault
325
+ using Cutlass2xGemmM128SmallN =
326
+ typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm; // same kernel type as Cutlass2xGemmM64
327
+ using Cutlass2xGemmM64 =
328
+ typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
329
+ using Cutlass2xGemmM32 =
330
+ typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
331
+ using Cutlass2xGemmM16 =
332
+ typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
333
+
334
+ uint32_t const m = a.size (0 ); // M: extent of a's first dimension
335
+ uint32_t const mp2 =
336
+ std::max (static_cast <uint32_t >(16 ), next_pow_2 (m)); // next power of 2 of m, floored at 16 (the smallest bucket handled below)
337
+ if (mp2 <= 16 ) {
338
+ // M in [1, 16] -> sm80_config_M16
339
+ return cutlass_gemm_caller<Cutlass2xGemmM16>(
340
+ out, a, b, std::forward<EpilogueArgs>(args)...);
341
+ } else if (mp2 <= 32 ) {
342
+ // M in (16, 32] -> sm80_config_M32
343
+ return cutlass_gemm_caller<Cutlass2xGemmM32>(
344
+ out, a, b, std::forward<EpilogueArgs>(args)...);
345
+ } else if (mp2 <= 64 ) {
346
+ // M in (32, 64] -> sm80_config_M64
347
+ return cutlass_gemm_caller<Cutlass2xGemmM64>(
348
+ out, a, b, std::forward<EpilogueArgs>(args)...);
349
+ } else if (mp2 <= 128 ) {
350
+ // M in (64, 128]: the choice additionally depends on N
351
+ uint32_t const n = out.size (1 ); // N: extent of out's second dimension
352
+ bool const small_n = n < 8192 ; // N < 8192 uses the M64 config, otherwise the default config
353
+ if (small_n) {
354
+ return cutlass_gemm_caller<Cutlass2xGemmM128SmallN>(
355
+ out, a, b, std::forward<EpilogueArgs>(args)...);
356
+ } else {
357
+ return cutlass_gemm_caller<Cutlass2xGemmM128BigN>(
358
+ out, a, b, std::forward<EpilogueArgs>(args)...);
359
+ }
360
+ } else {
361
+ // M in (128, inf) -> sm80_config_default
362
+ return cutlass_gemm_caller<Cutlass2xGemmDefault>(
363
+ out, a, b, std::forward<EpilogueArgs>(args)...);
364
+ }
365
+ }
366
+
255
367
void cutlass_scaled_mm_sm75 (torch::Tensor& out, torch::Tensor const & a,
256
368
torch::Tensor const & b,
257
369
torch::Tensor const & a_scales,
@@ -288,20 +400,13 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
288
400
TORCH_CHECK (a_scales.dtype () == torch::kFloat32 );
289
401
TORCH_CHECK (b_scales.dtype () == torch::kFloat32 );
290
402
291
- using TileShape = typename cutlass::gemm::GemmShape<128 , 128 , 64 >;
292
- using WarpShape = typename cutlass::gemm::GemmShape<64 , 64 , 64 >;
293
- using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
294
-
295
403
if (out.dtype () == torch::kBFloat16 ) {
296
- return cutlass_gemm_caller<cutlass_2x_gemm<
297
- cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t , cutlass::bfloat16_t ,
298
- ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5 >>(
299
- out, a, b, a_scales, b_scales);
404
+ return cutlass_gemm_sm80_dispatch<int8_t , cutlass::bfloat16_t ,
405
+ ScaledEpilogue>(out, a, b, a_scales,
406
+ b_scales);
300
407
} else {
301
408
TORCH_CHECK (out.dtype () == torch::kFloat16 );
302
- return cutlass_gemm_caller<cutlass_2x_gemm<
303
- cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t , cutlass::half_t ,
304
- ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5 >>(
409
+ return cutlass_gemm_sm80_dispatch<int8_t , cutlass::half_t , ScaledEpilogue>(
305
410
out, a, b, a_scales, b_scales);
306
411
}
307
412
}
0 commit comments