Commit 3bac3be

[FRONTEND] rename nv_override_capability -> arch (#5579)
1 parent 2efb067 commit 3bac3be

3 files changed: +33 −22 lines changed


include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_HIP_USE_BLOCK_PINGPONG",
     "TRITON_LLVM_DEBUG_ONLY",
     "TRITON_ENABLE_ASAN",
-    "TRITON_OVERRIDE_NV_CAPABILITY",
+    "TRITON_OVERRIDE_ARCH",
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",
     // clang-format on
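Because the renamed variable is listed in CACHE_INVALIDATING_ENV_VARS, changing it between runs invalidates cached kernels and forces recompilation. A minimal sketch of the environment-driven form of the override, assuming a CUDA-capable setup; the value must match the sm<N> pattern the backend validates further below:

import os

# The value must match r"^sm(\d+)$"; anything else makes the CUDA backend raise ValueError.
os.environ["TRITON_OVERRIDE_ARCH"] = "sm80"
# ... import triton and launch kernels here: codegen is limited to the sm80 feature set,
# while the actual hardware capability still determines the binary (cubin) format ...
os.environ.pop("TRITON_OVERRIDE_ARCH")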

python/test/unit/language/test_core.py

Lines changed: 8 additions & 8 deletions
@@ -6067,15 +6067,15 @@ def mul_add(data):


 # -----------------------
-# test override_nv_compute_capability
+# test override_arch
 # -----------------------


-@pytest.mark.parametrize("nv_compute_capability", [70, 80, 90])
+@pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90"])
 @pytest.mark.parametrize("env_var_override", [False, True])
-def test_override_nv_compute_capability(nv_compute_capability, env_var_override, device):
+def test_override_arch(arch, env_var_override, device):
     if not is_cuda():
-        pytest.skip('test_override_nv_compute_capability only for CUDA')
+        pytest.skip('arch only for CUDA')

     @triton.jit
     def simple(data, out):

@@ -6087,14 +6087,14 @@ def simple(data, out):
     out = torch.empty_like(data)

     if env_var_override:
-        os.environ["TRITON_OVERRIDE_NV_CAPABILITY"] = str(nv_compute_capability)
+        os.environ["TRITON_OVERRIDE_ARCH"] = str(arch)
         h = simple[(1, )](data, out)
-        os.environ.pop("TRITON_OVERRIDE_NV_CAPABILITY")
+        os.environ.pop("TRITON_OVERRIDE_ARCH")
     else:
-        h = simple[(1, )](data, out, override_nv_compute_capability=nv_compute_capability)
+        h = simple[(1, )](data, out, arch=arch)
     torch.testing.assert_close(data * 1.5 + 1.0, out)
     ttgir_cc = re.search(r'cuda:(\d+)', h.asm["ttgir"])
-    assert int(ttgir_cc.group(1)) == nv_compute_capability
+    assert ttgir_cc.group(1) == arch[2:]


 # -----------------------
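The else-branch above also shows the per-launch form: the new arch compile option replaces the old override_nv_compute_capability keyword and lands on CUDAOptions.arch. A minimal, self-contained sketch assuming a CUDA device; the kernel and variable names here are illustrative, not from the commit:

import re
import torch
import triton
import triton.language as tl


@triton.jit
def scale(data, out):
    # Illustrative kernel: out = data * 1.5 + 1.0 over a fixed 128-element block.
    idx = tl.arange(0, 128)
    tl.store(out + idx, tl.load(data + idx) * 1.5 + 1.0)


data = torch.randn((128, ), device="cuda", dtype=torch.float32)
out = torch.empty_like(data)

# Per-launch override: "arch" replaces the old override_nv_compute_capability kwarg.
h = scale[(1, )](data, out, arch="sm90")

# As in the test, the TTGIR target string reflects the overridden capability.
assert re.search(r"cuda:(\d+)", h.asm["ttgir"]).group(1) == "90"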

third_party/nvidia/backend/compiler.py

Lines changed: 24 additions & 13 deletions
@@ -122,7 +122,7 @@ class CUDAOptions:
     debug: bool = False
     backend_name: str = 'cuda'
     sanitize_overflow: bool = True
-    override_nv_compute_capability: int = None
+    arch: str = None

     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'

@@ -146,34 +146,45 @@ class CUDABackend(BaseBackend):
     def supports_target(target: GPUTarget):
         return target.backend == 'cuda'

+    def _parse_arch(self, arch):
+        pattern = r"^sm(\d+)$"
+        match = re.fullmatch(pattern, arch)
+        if not match:
+            raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
+        return int(match.group(1))
+
     def __init__(self, target: GPUTarget) -> None:
         super().__init__(target)
         # Capability can be overrided to limit feature set to a specific version
-        cap_override = os.getenv("TRITON_OVERRIDE_NV_CAPABILITY")
-        self.capability = int(cap_override) if cap_override is not None else target.arch
+        self.hw_capability = target.arch
+        self.sw_capability = self.hw_capability
+        arch = os.getenv("TRITON_OVERRIDE_ARCH")
+        if arch is not None:
+            self.sw_capability = self._parse_arch(arch)
         # HW Capability is used to determine the binary format
         self.hw_capability = target.arch
-        assert isinstance(self.capability, int)
+        assert isinstance(self.hw_capability, int)
+        assert isinstance(self.sw_capability, int)
         self.binary_ext = "cubin"

     def parse_options(self, opts) -> Any:
         args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
         if "supported_fp8_dtypes" not in args:
             supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
-            if self.capability >= 89:
+            if self.sw_capability >= 89:
                 supported_fp8_dtypes.add("fp8e4nv")
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

         if "deprecated_fp8_dtypes" not in args:
-            if self.capability >= 90:
+            if self.sw_capability >= 90:
                 args["deprecated_fp8_dtypes"] = ("fp8e4b15", )

         if "enable_fp_fusion" not in args:
             args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"

-        if "override_nv_compute_capability" in args and args["override_nv_compute_capability"] is not None:
-            self.capability = args["override_nv_compute_capability"]
-        args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
+        if args.get("arch", None) is not None:
+            self.sw_capability = self._parse_arch(args["arch"])
+        args["max_num_imprecise_acc_default"] = 2**30 if self.sw_capability == 90 else 0
         return CUDAOptions(**args)

     def pack_metadata(self, metadata):

@@ -190,7 +201,7 @@ def get_codegen_implementation(self):
         import triton.language.extra.cuda as cuda
         codegen_fns = {
             "convert_custom_types":
-            cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70,
+            cuda.convert_custom_float8_sm80 if self.sw_capability >= 80 else cuda.convert_custom_float8_sm70,
             "min_dot_size": min_dot_size(self.target)
         }
         return codegen_fns

@@ -401,12 +412,12 @@ def make_cubin(src, metadata, opt, capability):

     def add_stages(self, stages, options):
         stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
-        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
-        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
+        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.sw_capability)
+        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.sw_capability)
         stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.hw_capability)
         stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.hw_capability)

     @functools.lru_cache()
     def hash(self):
         version = get_ptxas_version()
-        return f'{version}-{self.capability}'
+        return f'{version}-{self.sw_capability}-{self.hw_capability}'
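The net effect of the compiler.py change is a split between a software capability (what the IR is generated for, possibly lowered via arch or TRITON_OVERRIDE_ARCH) and a hardware capability (what the PTX and cubin stages target). The parsing rule is small enough to restate standalone; the following is an illustrative copy of that validation, not an import from the backend:

import re


def parse_arch(arch: str) -> int:
    # Mirrors CUDABackend._parse_arch above: accept exactly "sm" followed by digits.
    pattern = r"^sm(\d+)$"
    match = re.fullmatch(pattern, arch)
    if not match:
        raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
    return int(match.group(1))


assert parse_arch("sm90") == 90  # becomes sw_capability, used for codegen decisions
# parse_arch("90") or parse_arch("sm90a") raise ValueError instead of passing through silently.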
