Skip to content

Commit 66b231b

Browse files
authored
[FRONTEND] Better error for num_cta > 1 on sm < 90 (#7812)
Fixes triton-lang/triton#7811
1 parent 8b2ed6d commit 66b231b

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

python/test/unit/language/test_core.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7190,6 +7190,26 @@ def simple(data, out):
71907190
assert amdgcn_gfx.group(1) == arch
71917191

71927192

7193+
def test_num_ctas_pre_sm90(device):
    """Check that requesting ``num_ctas=2`` on an unsupported target is rejected.

    On CUDA the kernel is warmed up for ``sm80`` and on HIP for ``gfx942``;
    in both cases ``warmup`` is expected to raise ``ValueError`` with the
    backend-specific message. Other backends are skipped.
    """
    if not (is_cuda() or is_hip()):
        pytest.skip("Only supported on CUDA and HIP")

    @triton.jit
    def _kernel(src):
        pass

    # Pick the target architecture and the regex the error message must match.
    arch, msg = (("sm80", r"num_ctas > 1 requires NVIDIA SM90\+ \(Hopper\)") if is_cuda() else
                 ("gfx942", r"num_ctas > 1 not supported for AMD GPUs"))

    src = torch.empty(1, device=device)
    launch_kwargs = dict(grid=(1, ), num_ctas=2, arch=arch)

    with pytest.raises(ValueError, match=msg):
        _kernel.warmup(src, **launch_kwargs)
7211+
7212+
71937213
# -----------------------
71947214
# test propagate_nan
71957215
# -----------------------

third_party/amd/backend/compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ def get_target_name(self, options) -> str:
110110
def parse_options(self, opts) -> Any:
111111
args = {'arch': knobs.runtime.override_arch or self.target.arch}
112112

113+
if opts.get("num_ctas", 1) > 1:
114+
raise ValueError("num_ctas > 1 not supported for AMD GPUs")
115+
113116
# Enable XF32 (TF32) for CDNA3 GPUs
114117
if self.target.arch == 'gfx942':
115118
allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)

third_party/nvidia/backend/compiler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ def parse_options(self, opts) -> Any:
174174
args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
175175
capability = int(self._parse_arch(args["arch"]))
176176

177+
if args.get("num_ctas", 1) > 1 and capability < 90:
178+
raise ValueError((f"num_ctas > 1 requires NVIDIA SM90+ (Hopper). "
179+
f"Current target is sm_{capability}. This configuration will fail. "
180+
f"Please set num_ctas=1 or target an SM90+ GPU."))
181+
177182
if "supported_fp8_dtypes" not in args:
178183
supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
179184
if capability >= 89:

0 commit comments

Comments
 (0)