@@ -284,7 +284,8 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func):
284284 arch = {"x86_64" : "64" , "arm64" : "aarch64" , "aarch64" : "aarch64" }[platform .machine ()]
285285 except KeyError :
286286 arch = platform .machine ()
287- url = url_func (arch , version )
287+ supported = {"Linux" : "linux" , "Darwin" : "linux" }
288+ url = url_func (supported [system ], arch , version )
288289 tmp_path = os .path .join (triton_cache_path , "nvidia" , name ) # path to cache the download
289290 dst_path = os .path .join (base_dir , os .pardir , "third_party" , "nvidia" , "backend" , dst_path ) # final binary path
290291 platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux"
@@ -500,61 +501,62 @@ def get_platform_dependent_src_path(subdir):
500501
501502download_and_copy (
502503 name = "ptxas" , src_path = "bin/ptxas" , dst_path = "bin/ptxas" , variable = "TRITON_PTXAS_PATH" ,
503- version = NVIDIA_TOOLCHAIN_VERSION ["ptxas" ], url_func = lambda arch , version :
504+ version = NVIDIA_TOOLCHAIN_VERSION ["ptxas" ], url_func = lambda system , arch , version :
504505 ((lambda version_major , version_minor1 , version_minor2 :
505- f"https://anaconda.org/nvidia/cuda-nvcc-tools/{ version } /download/linux -{ arch } /cuda-nvcc-tools-{ version } -0.tar.bz2"
506+ f"https://anaconda.org/nvidia/cuda-nvcc-tools/{ version } /download/{ system } -{ arch } /cuda-nvcc-tools-{ version } -0.tar.bz2"
506507 if int (version_major ) >= 12 and int (version_minor1 ) >= 5 else
507- f"https://anaconda.org/nvidia/cuda-nvcc/{ version } /download/linux -{ arch } /cuda-nvcc-{ version } -0.tar.bz2" )
508+ f"https://anaconda.org/nvidia/cuda-nvcc/{ version } /download/{ system } -{ arch } /cuda-nvcc-{ version } -0.tar.bz2" )
508509 (* version .split ('.' ))))
509510download_and_copy (
510511 name = "cuobjdump" ,
511512 src_path = "bin/cuobjdump" ,
512513 dst_path = "bin/cuobjdump" ,
513514 variable = "TRITON_CUOBJDUMP_PATH" ,
514515 version = NVIDIA_TOOLCHAIN_VERSION ["cuobjdump" ],
515- url_func = lambda arch , version :
516- f"https://anaconda.org/nvidia/cuda-cuobjdump/{ version } /download/linux -{ arch } /cuda-cuobjdump-{ version } -0.tar.bz2" ,
516+ url_func = lambda system , arch , version :
517+ f"https://anaconda.org/nvidia/cuda-cuobjdump/{ version } /download/{ system } -{ arch } /cuda-cuobjdump-{ version } -0.tar.bz2" ,
517518)
518519download_and_copy (
519520 name = "nvdisasm" ,
520521 src_path = "bin/nvdisasm" ,
521522 dst_path = "bin/nvdisasm" ,
522523 variable = "TRITON_NVDISASM_PATH" ,
523524 version = NVIDIA_TOOLCHAIN_VERSION ["nvdisasm" ],
524- url_func = lambda arch , version :
525- f"https://anaconda.org/nvidia/cuda-nvdisasm/{ version } /download/linux -{ arch } /cuda-nvdisasm-{ version } -0.tar.bz2" ,
525+ url_func = lambda system , arch , version :
526+ f"https://anaconda.org/nvidia/cuda-nvdisasm/{ version } /download/{ system } -{ arch } /cuda-nvdisasm-{ version } -0.tar.bz2" ,
526527)
527528download_and_copy (
528529 name = "cudacrt" , src_path = get_platform_dependent_src_path ("include" ), dst_path = "include" ,
529- variable = "TRITON_CUDACRT_PATH" , version = NVIDIA_TOOLCHAIN_VERSION ["cudacrt" ], url_func = lambda arch , version :
530+ variable = "TRITON_CUDACRT_PATH" , version = NVIDIA_TOOLCHAIN_VERSION ["cudacrt" ], url_func = lambda system , arch , version :
530531 ((lambda version_major , version_minor1 , version_minor2 :
531- f"https://anaconda.org/nvidia/cuda-crt-dev_linux -{ arch } /{ version } /download/noarch/cuda-crt-dev_linux -{ arch } -{ version } -0.tar.bz2"
532+ f"https://anaconda.org/nvidia/cuda-crt-dev_ { system } -{ arch } /{ version } /download/noarch/cuda-crt-dev_ { system } -{ arch } -{ version } -0.tar.bz2"
532533 if int (version_major ) >= 12 and int (version_minor1 ) >= 5 else
533- f"https://anaconda.org/nvidia/cuda-nvcc/{ version } /download/linux -{ arch } /cuda-nvcc-{ version } -0.tar.bz2" )
534+ f"https://anaconda.org/nvidia/cuda-nvcc/{ version } /download/{ system } -{ arch } /cuda-nvcc-{ version } -0.tar.bz2" )
534535 (* version .split ('.' ))))
535536download_and_copy (
536537 name = "cudart" , src_path = get_platform_dependent_src_path ("include" ), dst_path = "include" ,
537- variable = "TRITON_CUDART_PATH" , version = NVIDIA_TOOLCHAIN_VERSION ["cudart" ], url_func = lambda arch , version :
538+ variable = "TRITON_CUDART_PATH" , version = NVIDIA_TOOLCHAIN_VERSION ["cudart" ], url_func = lambda system , arch , version :
538539 ((lambda version_major , version_minor1 , version_minor2 :
539- f"https://anaconda.org/nvidia/cuda-cudart-dev_linux -{ arch } /{ version } /download/noarch/cuda-cudart-dev_linux -{ arch } -{ version } -0.tar.bz2"
540+ f"https://anaconda.org/nvidia/cuda-cudart-dev_ { system } -{ arch } /{ version } /download/noarch/cuda-cudart-dev_ { system } -{ arch } -{ version } -0.tar.bz2"
540541 if int (version_major ) >= 12 and int (version_minor1 ) >= 5 else
541- f"https://anaconda.org/nvidia/cuda-cudart-dev/{ version } /download/linux -{ arch } /cuda-cudart-dev-{ version } -0.tar.bz2"
542+ f"https://anaconda.org/nvidia/cuda-cudart-dev/{ version } /download/{ system } -{ arch } /cuda-cudart-dev-{ version } -0.tar.bz2"
542543 )(* version .split ('.' ))))
543544download_and_copy (
544545 name = "cupti" , src_path = get_platform_dependent_src_path ("include" ), dst_path = "include" ,
545- variable = "TRITON_CUPTI_INCLUDE_PATH" , version = NVIDIA_TOOLCHAIN_VERSION ["cupti" ], url_func = lambda arch , version :
546+ variable = "TRITON_CUPTI_INCLUDE_PATH" , version = NVIDIA_TOOLCHAIN_VERSION ["cupti" ],
547+ url_func = lambda system , arch , version :
546548 ((lambda version_major , version_minor1 , version_minor2 :
547- f"https://anaconda.org/nvidia/cuda-cupti-dev/{ version } /download/linux -{ arch } /cuda-cupti-dev-{ version } -0.tar.bz2"
549+ f"https://anaconda.org/nvidia/cuda-cupti-dev/{ version } /download/{ system } -{ arch } /cuda-cupti-dev-{ version } -0.tar.bz2"
548550 if int (version_major ) >= 12 and int (version_minor1 ) >= 5 else
549- f"https://anaconda.org/nvidia/cuda-cupti/{ version } /download/linux -{ arch } /cuda-cupti-{ version } -0.tar.bz2" )
551+ f"https://anaconda.org/nvidia/cuda-cupti/{ version } /download/{ system } -{ arch } /cuda-cupti-{ version } -0.tar.bz2" )
550552 (* version .split ('.' ))))
551553download_and_copy (
552554 name = "cupti" , src_path = get_platform_dependent_src_path ("lib" ), dst_path = "lib/cupti" ,
553- variable = "TRITON_CUPTI_LIB_PATH" , version = NVIDIA_TOOLCHAIN_VERSION ["cupti" ], url_func = lambda arch , version :
555+ variable = "TRITON_CUPTI_LIB_PATH" , version = NVIDIA_TOOLCHAIN_VERSION ["cupti" ], url_func = lambda system , arch , version :
554556 ((lambda version_major , version_minor1 , version_minor2 :
555- f"https://anaconda.org/nvidia/cuda-cupti-dev/{ version } /download/linux -{ arch } /cuda-cupti-dev-{ version } -0.tar.bz2"
557+ f"https://anaconda.org/nvidia/cuda-cupti-dev/{ version } /download/{ system } -{ arch } /cuda-cupti-dev-{ version } -0.tar.bz2"
556558 if int (version_major ) >= 12 and int (version_minor1 ) >= 5 else
557- f"https://anaconda.org/nvidia/cuda-cupti/{ version } /download/linux -{ arch } /cuda-cupti-{ version } -0.tar.bz2" )
559+ f"https://anaconda.org/nvidia/cuda-cupti/{ version } /download/{ system } -{ arch } /cuda-cupti-{ version } -0.tar.bz2" )
558560 (* version .split ('.' ))))
559561
560562backends = [* BackendInstaller .copy (["intel" , "nvidia" , "amd" ]), * BackendInstaller .copy_externals ()]
0 commit comments