@@ -31,16 +31,16 @@ def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m,
3131 return check_dot_compatibility
3232
3333
34- def get_ptxas () -> knobs .NvidiaTool :
35- return knobs .nvidia .ptxas
34+ def get_ptxas (arch : int ) -> knobs .NvidiaTool :
35+ return knobs .nvidia .ptxas_blackwell if arch >= 100 else knobs . nvidia . ptxas
3636
3737
3838@functools .lru_cache ()
39- def get_ptxas_version ():
39+ def get_ptxas_version (arch : int = 80 ):
4040 mock_ver = knobs .nvidia .mock_ptx_version
4141 if mock_ver is not None :
4242 return mock_ver # This is not really a version of ptxas, but it is good enough for testing
43- version = subprocess .check_output ([get_ptxas ().path , "--version" ]).decode ("utf-8" )
43+ version = subprocess .check_output ([get_ptxas (arch ).path , "--version" ]).decode ("utf-8" )
4444 return version
4545
4646
@@ -71,7 +71,7 @@ def ptx_get_version(cuda_version) -> int:
7171def get_ptx_version_from_options (options , arch : int ):
7272 ptx_version = options .ptx_version
7373 if ptx_version is None :
74- cuda_version = get_ptxas ().version
74+ cuda_version = get_ptxas (arch ).version
7575 ptx_version = ptx_get_version (cuda_version )
7676 return ptx_version
7777
@@ -465,7 +465,7 @@ def make_ptx(self, src, metadata, opt, capability):
465465 return ret
466466
467467 def make_cubin (self , src , metadata , opt , capability ):
468- ptxas = get_ptxas ().path
468+ ptxas = get_ptxas (self . target . arch ).path
469469 with tempfile .NamedTemporaryFile (delete = False , mode = 'w' , suffix = '.ptx' ) as fsrc , \
470470 tempfile .NamedTemporaryFile (delete = False , mode = 'r' , suffix = '.log' ) as flog :
471471 fsrc .write (src )
@@ -555,5 +555,5 @@ def add_stages(self, stages, options, language):
555555
556556 @functools .lru_cache ()
557557 def hash (self ):
558- version = get_ptxas_version ()
558+ version = get_ptxas_version (self . target . arch )
559559 return f'{ version } -{ self .target .arch } '
0 commit comments