Commit a4ddf26

Browse files
authored
PDL patch for TGV GEMM (#1877)
## 📌 Description

Add the missing nvcc flags for TGV GEMM, including the one that enables the PDL feature. Also add safeguarding checks to TGV GEMM to prevent tensor sizes that are too large for TMA.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 792dcb1 · commit a4ddf26

3 files changed: +16 −2 lines changed


csrc/tgv_gemm.cu

Lines changed: 8 additions & 0 deletions

```diff
@@ -148,6 +148,14 @@ Tensor tgv_gemm(Tensor const& mat1, Tensor const& mat2, Optional<Tensor> bias, i
   int K = mat1->shape[1];
   int N = mat2->shape[1];
 
+  int64_t element_size = get_element_size(mat1);
+  TVM_FFI_ICHECK(int64_t(M) * N * element_size < std::numeric_limits<int32_t>::max())
+      << "TMA plane stride (M * N * element_size) exceeds INT32_MAX; tensor too large for TMA";
+  TVM_FFI_ICHECK(int64_t(M) * K * element_size < std::numeric_limits<int32_t>::max())
+      << "TMA plane stride (M * K * element_size) exceeds INT32_MAX; mat1 too large for TMA";
+  TVM_FFI_ICHECK(int64_t(N) * K * element_size < std::numeric_limits<int32_t>::max())
+      << "TMA plane stride (N * K * element_size) exceeds INT32_MAX; mat2 too large for TMA";
+
   // validity check for bias
   if (bias.has_value()) {
     TVM_FFI_ICHECK_EQ(bias.value()->device.device_type, kDLCUDA) << "Bias tensor must be on CUDA";
```
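
The bound these checks enforce can be illustrated with a small Python sketch; this is a mock-up of the C++ condition, not FlashInfer code, and the 16384/65536 sizes are illustrative assumptions:

```python
# Python mock-up of the TVM_FFI_ICHECK condition added above (illustrative only).
INT32_MAX = 2**31 - 1

def tma_plane_stride_ok(rows: int, cols: int, element_size: int) -> bool:
    """The TMA plane stride rows * cols * element_size must fit in a signed 32-bit int."""
    return rows * cols * element_size < INT32_MAX

# With 2-byte bfloat16 elements, a 16384 x 16384 plane (512 MiB) passes,
# while a 65536 x 65536 plane (8 GiB) would trip the new check.
assert tma_plane_stride_ok(16384, 16384, element_size=2)
assert not tma_plane_stride_ok(65536, 65536, element_size=2)
```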

flashinfer/gemm.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -584,7 +584,7 @@ def tgv_gemm_sm100(
         pdl: Whether to use PDL (persistent data loader), defaults to False
 
     Returns:
-        Output tensor of shape (M, N)
+        Output tensor of shape (M, N) in row-major layout
 
     Supported dtypes:
     - torch.bfloat16
```
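
For reference, a minimal usage sketch of the documented entry point. Only the `pdl` flag, bfloat16 support, and the (M, N) row-major output are confirmed by this diff; the positional `mat1`/`mat2` arguments, their shapes, and layout requirements are assumptions taken from the C++ binding above:

```python
# Hedged usage sketch; argument order mirrors the C++ tgv_gemm(mat1, mat2, bias, ...) binding,
# and the operand shapes/layouts are assumed rather than taken from this diff.
import torch
from flashinfer.gemm import tgv_gemm_sm100

M, K, N = 128, 4096, 4096
mat1 = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
mat2 = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")

# pdl=True requests the PDL launch path enabled by this patch.
out = tgv_gemm_sm100(mat1, mat2, pdl=True)
assert out.shape == (M, N)  # row-major (M, N) output, per the updated docstring
```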

flashinfer/jit/gemm/core.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -445,7 +445,13 @@ def gen_tgv_gemm_sm10x_module(
     return gen_jit_spec(
         module_name,
         source_paths,
-        extra_cuda_cflags=sm100f_nvcc_flags if use_sm_100f else sm100a_nvcc_flags,
+        extra_cuda_cflags=[
+            "--expt-relaxed-constexpr",
+            "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
+        ]
+        + sm100f_nvcc_flags
+        if use_sm_100f
+        else sm100a_nvcc_flags,
         extra_include_paths=[
             jit_env.FLASHINFER_INCLUDE_DIR,
             jit_env.FLASHINFER_CSRC_DIR,
```
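
One parsing note: Python's conditional expression binds more loosely than `+`, so the expression above groups as `([extra flags] + sm100f_nvcc_flags) if use_sm_100f else sm100a_nvcc_flags`. A small standalone sketch with placeholder flag lists (the real contents of `sm100f_nvcc_flags` and `sm100a_nvcc_flags` are not shown here and are assumed):

```python
# Placeholder flag lists; the real sm100f/sm100a contents are assumptions.
sm100f_nvcc_flags = ["<sm_100f arch flags>"]
sm100a_nvcc_flags = ["<sm_100a arch flags>"]

def build_cflags(use_sm_100f: bool) -> list:
    # Same shape as the diff: it parses as
    # ([extra flags] + sm100f_nvcc_flags) if use_sm_100f else sm100a_nvcc_flags
    return (
        [
            "--expt-relaxed-constexpr",
            "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
        ]
        + sm100f_nvcc_flags
        if use_sm_100f
        else sm100a_nvcc_flags
    )

assert build_cflags(True)[0] == "--expt-relaxed-constexpr"
assert build_cflags(False) == sm100a_nvcc_flags
```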
