@@ -122,7 +122,7 @@ class CUDAOptions:
     debug: bool = False
     backend_name: str = 'cuda'
     sanitize_overflow: bool = True
-    override_nv_compute_capability: int = None
+    arch: str = None
 
     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
@@ -146,34 +146,45 @@ class CUDABackend(BaseBackend):
     def supports_target(target: GPUTarget):
         return target.backend == 'cuda'
 
+    def _parse_arch(self, arch):
+        pattern = r"^sm(\d+)$"
+        match = re.fullmatch(pattern, arch)
+        if not match:
+            raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
+        return int(match.group(1))
+
     def __init__(self, target: GPUTarget) -> None:
         super().__init__(target)
         # Capability can be overridden to limit the feature set to a specific version
-        cap_override = os.getenv("TRITON_OVERRIDE_NV_CAPABILITY")
-        self.capability = int(cap_override) if cap_override is not None else target.arch
+        self.hw_capability = target.arch
+        self.sw_capability = self.hw_capability
+        arch = os.getenv("TRITON_OVERRIDE_ARCH")
+        if arch is not None:
+            self.sw_capability = self._parse_arch(arch)
         # HW Capability is used to determine the binary format
         self.hw_capability = target.arch
-        assert isinstance(self.capability, int)
+        assert isinstance(self.hw_capability, int)
+        assert isinstance(self.sw_capability, int)
         self.binary_ext = "cubin"
 
     def parse_options(self, opts) -> Any:
         args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
         if "supported_fp8_dtypes" not in args:
             supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
-            if self.capability >= 89:
+            if self.sw_capability >= 89:
                 supported_fp8_dtypes.add("fp8e4nv")
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
 
         if "deprecated_fp8_dtypes" not in args:
-            if self.capability >= 90:
+            if self.sw_capability >= 90:
                 args["deprecated_fp8_dtypes"] = ("fp8e4b15", )
 
         if "enable_fp_fusion" not in args:
             args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
 
-        if "override_nv_compute_capability" in args and args["override_nv_compute_capability"] is not None:
-            self.capability = args["override_nv_compute_capability"]
-        args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
+        if args.get("arch", None) is not None:
+            self.sw_capability = self._parse_arch(args["arch"])
+        args["max_num_imprecise_acc_default"] = 2**30 if self.sw_capability == 90 else 0
         return CUDAOptions(**args)
 
     def pack_metadata(self, metadata):
@@ -190,7 +201,7 @@ def get_codegen_implementation(self):
         import triton.language.extra.cuda as cuda
         codegen_fns = {
             "convert_custom_types":
-            cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70,
+            cuda.convert_custom_float8_sm80 if self.sw_capability >= 80 else cuda.convert_custom_float8_sm70,
             "min_dot_size": min_dot_size(self.target)
         }
         return codegen_fns
@@ -401,12 +412,12 @@ def make_cubin(src, metadata, opt, capability):
 
     def add_stages(self, stages, options):
         stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
-        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
-        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
+        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.sw_capability)
+        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.sw_capability)
         stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.hw_capability)
         stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.hw_capability)
 
     @functools.lru_cache()
     def hash(self):
         version = get_ptxas_version()
-        return f'{version}-{self.capability}'
+        return f'{version}-{self.sw_capability}-{self.hw_capability}'
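
For reference, a minimal sketch of how the new override is meant to be driven, based only on this diff. The parse_arch helper below is a standalone copy of the logic added in CUDABackend._parse_arch, and the sm_90 machine is a hypothetical example; this is an illustration, not part of the patch.

# Illustrative sketch only, assuming the behaviour shown in this diff:
# TRITON_OVERRIDE_ARCH caps the *software* capability used for feature
# gating and codegen (sw_capability), while the cubin is still emitted
# for the real hardware (hw_capability stays target.arch).
import os
import re

def parse_arch(arch):
    # Standalone copy of CUDABackend._parse_arch: "sm<digits>" -> int.
    pattern = r"^sm(\d+)$"
    match = re.fullmatch(pattern, arch)
    if not match:
        raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
    return int(match.group(1))

# e.g. on a hypothetical sm_90 machine, restrict codegen features to sm_80.
# The variable must be set before the CUDA backend is initialized.
os.environ["TRITON_OVERRIDE_ARCH"] = "sm80"
assert parse_arch(os.environ["TRITON_OVERRIDE_ARCH"]) == 80

The same cap can also be requested per compile through the new arch option on CUDAOptions (e.g. arch="sm80"), which parse_options feeds through _parse_arch in the same way.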