
Commit 2d4aafe

Commit message: resolve
Merge of 2 parents: b918359 + ecd6648

10 files changed, +335 -148 lines changed

.github/workflows/new-issue.yml

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+name: Triage new issues
+
+on:
+  issues:
+    types: [opened]
+
+permissions:
+  issues: write
+
+jobs:
+  triage:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Add needs-triage label
+        uses: actions/github-script@v7
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          script: |
+            const issueNumber = context.issue.number;
+            const { owner, repo } = context.repo;
+            const labelName = 'needs-triage';
+            try {
+              await github.rest.repos.getLabel({ owner, repo, name: labelName });
+            } catch (error) {
+              if (error.status === 404) {
+                throw new Error(`Required label '${labelName}' does not exist in ${owner}/${repo}. Please create it in the repository settings.`);
+              }
+              throw error;
+            }
+            await github.rest.issues.addLabels({
+              owner,
+              repo,
+              issue_number: issueNumber,
+              labels: [labelName],
+            });
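A note on the failure path: the script intentionally does not create the needs-triage label when it is missing. In that case the github-script step throws with the message above, and a maintainer has to create the label once in the repository's label settings before the triage job can start applying it to newly opened issues.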

flashinfer/artifacts.py

Lines changed: 2 additions & 2 deletions
@@ -70,7 +70,7 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
 
 
 class ArtifactPath:
-    TRTLLM_GEN_FMHA: str = "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "9ef9e6243df03ab2c3fca1f0398a38cf1011d1e1/batched_gemm-45beda1-7bdba93/"
     )
@@ -83,7 +83,7 @@ class ArtifactPath:
 
 class MetaInfoHash:
     TRTLLM_GEN_FMHA: str = (
-        "875f50e8f466120b1a59b94397835b86fad785942b4036823230465bc618b919"
+        "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
     )
     TRTLLM_GEN_BMM: str = (
         "9490085267aed30a387bfff024a0605e1ca4d39dfe06a5abc159d7d7e129bdf4"

flashinfer/gemm.py

Lines changed: 25 additions & 1 deletion
@@ -38,7 +38,12 @@
     last_positive_power_of_2,
 )
 from .jit.cubin_loader import get_cubin
-from .utils import is_sm100a_supported, is_sm120a_supported, is_sm121a_supported
+from .utils import (
+    is_sm100a_supported,
+    is_sm120a_supported,
+    is_sm121a_supported,
+    LibraryError,
+)
 
 CUDNN_AVAILABLE = False
 try:
@@ -2112,6 +2117,15 @@ def mm_fp4(
         raise ValueError("TRTLLM FP4 GEMM is not supported on SM110.")
     if backend != "cudnn" and not use_nvfp4:
         raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
+    if (
+        backend == "cudnn"
+        and not use_nvfp4
+        and _match_sm_version(a.device, ["120"])
+        and cudnn.backend_version() < 91400
+    ):
+        raise LibraryError(
+            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
+        )
 
     # allocate the output tensor if not provided
     if out is None:
@@ -3078,6 +3092,11 @@ def group_deepgemm_fp8_nt_groupwise(
     """
     from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
 
+    if not _match_sm_version(a.device, ["100", "103"]):
+        raise ValueError(
+            "m_grouped_fp8_gemm_nt_contiguous is only supported on SM100, SM103."
+        )
+
     if out is None:
         out_dtype = out_dtype or torch.bfloat16
         out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.device)
@@ -3206,6 +3225,11 @@ def batch_deepgemm_fp8_nt_groupwise(
     """
     from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_masked
 
+    if not _match_sm_version(a.device, ["100", "103"]):
+        raise ValueError(
+            "m_grouped_fp8_gemm_nt_masked is only supported on SM100, SM103."
+        )
+
     if out is None:
         out_dtype = out_dtype or torch.bfloat16
         out = torch.empty(
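For illustration, a sketch of how a caller might react to the new guard in mm_fp4. The wrapper below is hypothetical (not part of this commit); it assumes the tensors are prepared the same way as in tests/test_mm_fp4.py and follows the same positional argument order.

import torch
from flashinfer.gemm import mm_fp4
from flashinfer.utils import LibraryError

def mxfp4_gemm_or_skip(a_fp4, b_fp4, a_scale, b_scale, alpha, out):
    """Run a cudnn mxfp4 GEMM, reporting the new SM120/cuDNN<9.14 guard instead of crashing.

    The argument order mirrors the mm_fp4 call in tests/test_mm_fp4.py; this
    helper and its name are illustrative only.
    """
    try:
        mm_fp4(
            a_fp4, b_fp4, a_scale, b_scale, alpha, torch.bfloat16, out,
            backend="cudnn", use_nvfp4=False,
        )
        return out
    except LibraryError as e:
        # Raised by mm_fp4 only for cudnn + mxfp4 on SM120 with cuDNN backend < 9.14.0.
        print(f"unsupported configuration: {e}")
        return None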

flashinfer/utils.py

Lines changed: 7 additions & 7 deletions
@@ -52,15 +52,15 @@ class TensorLayout(Enum):
 
 
 class GPUArchitectureError(Exception):
-    def __init__(self, msg: str):
-        self.msg = msg
-        super().__init__(self.msg)
+    """Custom exception for GPU architecture-related errors."""
 
-    def __str__(self):
-        return self.msg
+    pass
 
-    def __repr__(self):
-        return self.msg
+
+class LibraryError(Exception):
+    """Custom exception for library-related errors."""
+
+    pass
 
 
 def _expand_5d(x: torch.Tensor, kv_layout: str) -> torch.Tensor:
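With the simplified definition, Exception itself stores the constructor message and str(e) returns it, which is what lets the call sites below switch from e.msg to str(e). A minimal demonstration (the error message is illustrative):

class GPUArchitectureError(Exception):
    """Custom exception for GPU architecture-related errors."""

    pass


try:
    raise GPUArchitectureError("unsupported GPU architecture for this kernel")
except GPUArchitectureError as e:
    # pytest.skip(str(e)) receives exactly the message passed to the constructor.
    assert str(e) == "unsupported GPU architecture for this kernel"
    assert e.args == ("unsupported GPU architecture for this kernel",)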

tests/test_attention_sink_blackwell.py

Lines changed: 1 addition & 8 deletions
@@ -18,7 +18,6 @@
 import pytest
 import torch
 from sink_attention_reference import sink_attention_unified
-from conftest import assert_close_with_mismatch_tolerance
 
 import flashinfer
 from flashinfer.utils import get_compute_capability
@@ -122,13 +121,7 @@ def test_blackwell_trtllm_gen_decode_attention_sink(
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
 
-    assert_close_with_mismatch_tolerance(
-        o_ref,
-        output,
-        atol=atol,
-        rtol=rtol,
-        max_mismatched_elements=int(output.numel() * 0.01),
-    )
+    torch.testing.assert_close(o_ref, output, atol=atol, rtol=rtol)
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
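This is a behavioral tightening as well as a cleanup: the removed helper accepted up to 1% mismatched elements (max_mismatched_elements=int(output.numel() * 0.01)), whereas torch.testing.assert_close fails if any element falls outside the given atol/rtol.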

tests/test_groupwise_scaled_gemm_fp8.py

Lines changed: 10 additions & 0 deletions
@@ -202,6 +202,11 @@ def test_fp8_groupwise_group_deepgemm(
     group_size,
     out_dtype,
 ):
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if compute_capability[0] != 10:
+        pytest.skip(
+            "group_deepgemm_fp8_nt_groupwise is only supported on SM100, SM103 in trtllm backend."
+        )
     torch.random.manual_seed(0)
     m_per_group = m // group_size
     if m_per_group < 128:
@@ -245,6 +250,11 @@ def test_fp8_groupwise_batch_deepgemm_masked(
     group_size,
     out_dtype,
 ):
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if compute_capability[0] != 10:
+        pytest.skip(
+            "batch_deepgemm_fp8_nt_groupwise is only supported on SM100, SM103."
+        )
     torch.random.manual_seed(0)
     n, k = nk
     a = torch.randn((group_size, m, k), device="cuda", dtype=torch.float32)
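The skip works because get_compute_capability returns an indexable (major, minor) pair, so SM100 and SM103 both report major version 10 while SM120 reports 12. A small sketch of the same gate, written against torch.cuda.get_device_capability as a stand-in for the flashinfer helper:

import torch

def is_sm10x(device: torch.device) -> bool:
    """True for SM100/SM103-class GPUs (compute capability 10.x), the only
    architectures the deepgemm groupwise kernels are gated to above."""
    major, minor = torch.cuda.get_device_capability(device)
    # SM100 -> (10, 0), SM103 -> (10, 3); SM120 would be (12, 0) and is skipped.
    return major == 10

# Usage: is_sm10x(torch.device("cuda"))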

tests/test_mm_fp4.py

Lines changed: 31 additions & 18 deletions
@@ -8,7 +8,7 @@
     nvfp4_quantize,
     mxfp4_quantize,
 )
-from flashinfer.utils import get_compute_capability
+from flashinfer.utils import get_compute_capability, LibraryError
 
 
 # TODO: Consdier splitting this function up for the various backends
@@ -25,10 +25,10 @@ def test_mm_fp4(
 ):
     use_nvfp4 = fp4_type == "nvfp4"
 
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
     if backend == "trtllm":
         if res_dtype == torch.float16:
             pytest.skip("Skipping test for trtllm fp4 with float16")
-        compute_capability = get_compute_capability(torch.device(device="cuda"))
         if compute_capability[0] in [11, 12]:
             pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.")
     if not use_128x4_sf_layout and backend != "trtllm":
@@ -71,23 +71,36 @@ def test_mm_fp4(
 
     res = torch.empty([m, n], device="cuda", dtype=res_dtype)
 
-    with autotune(auto_tuning):
-        mm_fp4(
-            input_fp4,
-            mat2_fp4.T,
-            input_inv_s,
-            mat2_inv_s.T,
-            alpha,
-            res_dtype,
-            res,
-            block_size=block_size,
-            use_8x4_sf_layout=not use_128x4_sf_layout,
-            backend=backend,
-            use_nvfp4=use_nvfp4,
-        )
+    try:
+        with autotune(auto_tuning):
+            mm_fp4(
+                input_fp4,
+                mat2_fp4.T,
+                input_inv_s,
+                mat2_inv_s.T,
+                alpha,
+                res_dtype,
+                res,
+                block_size=block_size,
+                use_8x4_sf_layout=not use_128x4_sf_layout,
+                backend=backend,
+                use_nvfp4=use_nvfp4,
+            )
 
-    cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
-    assert cos_sim > 0.97
+        cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
+        assert cos_sim > 0.97
+    except LibraryError:
+        # TODO: Remove this check once cuDNN backend version is updated to 9.14.0
+        if (
+            backend == "cudnn"
+            and not use_nvfp4
+            and (compute_capability[0] == 12 and compute_capability[1] == 0)
+        ):
+            pytest.xfail(
+                "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
+            )
+        else:
+            pytest.fail("Unexpected LibraryError")
 
 
 if __name__ == "__main__":
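The try/except keeps the existing parameter matrix intact: on SM120 the cudnn + mxfp4 combination now raises LibraryError inside mm_fp4, and the test converts exactly that case into an expected failure via pytest.xfail, while a LibraryError from any other combination still fails the test outright. The TODO marks this as temporary until the cuDNN backend reaches 9.14.0.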

tests/test_triton_cascade.py

Lines changed: 4 additions & 4 deletions
@@ -21,7 +21,7 @@ def test_merge_state(seq_len, num_heads, head_dim):
         assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
         assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
     except GPUArchitectureError as e:
-        pytest.skip(e.msg)
+        pytest.skip(str(e))
 
 
 @pytest.mark.parametrize("seq_len", [2048])
@@ -44,7 +44,7 @@ def test_merge_state_in_place(seq_len, num_heads, head_dim):
         assert torch.allclose(s, s_std, atol=1e-2)
 
     except GPUArchitectureError as e:
-        pytest.skip(e.msg)
+        pytest.skip(str(e))
 
 
 @pytest.mark.parametrize("seq_len", [2048])
@@ -63,7 +63,7 @@ def test_merge_states(seq_len, num_states, num_heads, head_dim):
         assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
         assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
     except GPUArchitectureError as e:
-        pytest.skip(e.msg)
+        pytest.skip(str(e))
 
 
 @pytest.mark.parametrize("seq_len", [2048])
@@ -94,4 +94,4 @@ def test_variable_length_merge_states(seq_len, num_heads, head_dim):
         assert torch.allclose(v_merged[i], v_merged_std, atol=1e-2)
         assert torch.allclose(s_merged[i], s_merged_std, atol=1e-2)
     except GPUArchitectureError as e:
-        pytest.skip(e.msg)
+        pytest.skip(str(e))
