Skip to content

Commit 3b3c852

Browse files
authored
[PTXAS] Upgrade ptxas to 12.9.86 for blackwell (#8476)
Keep the old ptxas (12.8.93) for pre-Blackwell architectures, because the newer ptxas has functional regressions there
1 parent 273649e commit 3b3c852

File tree

5 files changed

+21
-7
lines changed

5 files changed

+21
-7
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ cmake-build-*
6565
cuobjdump
6666
nvdisasm
6767
ptxas
68+
ptxas-blackwell
6869

6970
# Third-party include
7071
third_party/nvidia/backend/include

cmake/nvidia-toolchain-version.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{
2+
"ptxas-blackwell": "12.9.86",
23
"ptxas": "12.8.93",
34
"cuobjdump": "12.8.55",
45
"nvdisasm": "12.8.55",

python/triton/knobs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ class nvidia_knobs(base_knobs):
488488
cuobjdump: env_nvidia_tool = env_nvidia_tool("cuobjdump")
489489
nvdisasm: env_nvidia_tool = env_nvidia_tool("nvdisasm")
490490
ptxas: env_nvidia_tool = env_nvidia_tool("ptxas")
491+
ptxas_blackwell: env_nvidia_tool = env_nvidia_tool("ptxas-blackwell")
491492

492493
dump_nvptx: env_bool = env_bool("NVPTX_ENABLE_DUMP")
493494
disable_ptxas_opt: env_bool = env_bool("DISABLE_PTXAS_OPT")

setup.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,17 @@ def download_and_copy_dependencies():
541541
url_func=lambda system, arch, version:
542542
f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz",
543543
)
544+
545+
# We download a separate, newer ptxas for Blackwell only; the newer ptxas has some bugs when used for Hopper, so other targets keep the default ptxas
546+
download_and_copy(
547+
name="nvcc",
548+
src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}",
549+
dst_path="bin/ptxas-blackwell",
550+
variable="TRITON_PTXAS_BLACKWELL_PATH",
551+
version=NVIDIA_TOOLCHAIN_VERSION["ptxas-blackwell"],
552+
url_func=lambda system, arch, version:
553+
f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz",
554+
)
544555
download_and_copy(
545556
name="cuobjdump",
546557
src_func=lambda system, arch, version:

third_party/nvidia/backend/compiler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@ def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m,
3131
return check_dot_compatibility
3232

3333

34-
def get_ptxas() -> knobs.NvidiaTool:
35-
return knobs.nvidia.ptxas
34+
def get_ptxas(arch: int) -> knobs.NvidiaTool:
35+
return knobs.nvidia.ptxas_blackwell if arch >= 100 else knobs.nvidia.ptxas
3636

3737

3838
@functools.lru_cache()
39-
def get_ptxas_version():
39+
def get_ptxas_version(arch: int = 80):
4040
mock_ver = knobs.nvidia.mock_ptx_version
4141
if mock_ver is not None:
4242
return mock_ver # This is not really a version of ptxas, but it is good enough for testing
43-
version = subprocess.check_output([get_ptxas().path, "--version"]).decode("utf-8")
43+
version = subprocess.check_output([get_ptxas(arch).path, "--version"]).decode("utf-8")
4444
return version
4545

4646

@@ -71,7 +71,7 @@ def ptx_get_version(cuda_version) -> int:
7171
def get_ptx_version_from_options(options, arch: int):
7272
ptx_version = options.ptx_version
7373
if ptx_version is None:
74-
cuda_version = get_ptxas().version
74+
cuda_version = get_ptxas(arch).version
7575
ptx_version = ptx_get_version(cuda_version)
7676
return ptx_version
7777

@@ -465,7 +465,7 @@ def make_ptx(self, src, metadata, opt, capability):
465465
return ret
466466

467467
def make_cubin(self, src, metadata, opt, capability):
468-
ptxas = get_ptxas().path
468+
ptxas = get_ptxas(self.target.arch).path
469469
with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
470470
tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
471471
fsrc.write(src)
@@ -555,5 +555,5 @@ def add_stages(self, stages, options, language):
555555

556556
@functools.lru_cache()
557557
def hash(self):
558-
version = get_ptxas_version()
558+
version = get_ptxas_version(self.target.arch)
559559
return f'{version}-{self.target.arch}'

0 commit comments

Comments
 (0)