Skip to content

Commit cebe374

Browse files
py-torch: add cusparselt variant (spack#1717)
* py-torch: add cusparselt variant * py-nvidia-cusparselt: delete package * [@spackbot] updating style on behalf of thomas-bouvier * nvidia-cusparselt: disable cuda 13 for now * [@spackbot] updating style on behalf of thomas-bouvier * cusparselt: address reviews --------- Co-authored-by: thomas-bouvier <[email protected]>
1 parent 8f50ef9 commit cebe374

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright Spack Project Developers. See COPYRIGHT file for details.
2+
#
3+
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
4+
5+
import platform
6+
7+
from spack.package import *
8+
9+
10+
class Cusparselt(Package):
11+
"""A high-performance CUDA library dedicated to general matrix-matrix operations
12+
in which at least one operand is a structured sparse matrix with 50% sparsity ratio."""
13+
14+
homepage = "https://docs.nvidia.com/cuda/cusparselt/"
15+
16+
skip_version_audit = ["platform=darwin", "platform=windows"]
17+
18+
maintainers("thomas-bouvier")
19+
20+
system = platform.system().lower()
21+
arch = platform.machine()
22+
if "linux" in system and arch == "x86_64":
23+
# version(
24+
# "0.8.1-cuda130",
25+
# sha256="82dd3e5ebc199a27011f58857a80cd825e77bba634ab2286ba3d4e13115db89a",
26+
# url="https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-x86_64/libcusparse_lt-linux-x86_64-0.8.1.1_cuda13-archive.tar.xz",
27+
# )
28+
version(
29+
"0.8.1-cuda120",
30+
sha256="b34272e683e9f798435af05dc124657d1444cd0e13802c3d2f3152e31cd898a3",
31+
url="https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-x86_64/libcusparse_lt-linux-x86_64-0.8.1.1_cuda12-archive.tar.xz",
32+
)
33+
elif "linux" in system and arch == "aarch64":
34+
# version(
35+
# "0.8.1-cuda130",
36+
# sha256="0fcf5808f66c71f755b4a73af2e955292e4334fec6a851eea1ac2e20878602b7",
37+
# url="https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-aarch64/libcusparse_lt-linux-aarch64-0.8.1.1_cuda13-archive.tar.xz",
38+
# )
39+
version(
40+
"0.8.1-cuda120",
41+
sha256="5426a897c73a9b98a83c4e132d15abc63dc4a00f7e38266e7b82c42cd58a01e1",
42+
url="https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-aarch64/libcusparse_lt-linux-aarch64-0.8.1.1_cuda12-archive.tar.xz",
43+
)
44+
45+
# cuda130_versions = ("@0.8.1-cuda130",)
46+
cuda120_versions = ("@0.8.1-cuda120",)
47+
48+
# for v in cuda130_versions:
49+
# depends_on("cuda@13", when=v, type=("build", "run"))
50+
for v in cuda120_versions:
51+
depends_on("cuda@12", when=v, type=("build", "run"))
52+
53+
depends_on("c", type="build")
54+
depends_on("cxx", type="build")
55+
56+
# Installation instructions
57+
def install(self, spec, prefix):
58+
# Create installation directories
59+
mkdirp(prefix.lib)
60+
mkdirp(prefix.include)
61+
62+
# Copy library files
63+
install_tree("lib", prefix.lib)
64+
65+
# Copy header files
66+
install_tree("include", prefix.include)

repos/spack_repo/builtin/packages/py_torch/package.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
110110
_desc = "Build the flash_attention kernel for scaled dot product attention"
111111
variant("flash_attention", default=True, description=_desc, when="@1.13:+cuda")
112112
variant("flash_attention", default=True, description=_desc, when="@1.13:+rocm")
113+
variant("cusparselt", default=True, description="Use NVIDIA cuSPARSELt", when="@2.1: +cuda")
113114
# py-torch has strict dependencies on old protobuf/py-protobuf versions that
114115
# cause problems with other packages that require newer versions of protobuf
115116
# and py-protobuf --> provide an option to use the internal/vendored protobuf.
@@ -314,6 +315,7 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
314315
depends_on("ucc", when="+ucc")
315316
depends_on("ucx", when="+ucc")
316317
depends_on("mkl", when="+mkldnn")
318+
depends_on("cusparselt", when="+cusparselt")
317319

318320
# Test dependencies
319321
with default_args(type="test"):
@@ -678,6 +680,7 @@ def enable_or_disable(variant, keyword="USE", var=None):
678680
env.set("CUDNN_INCLUDE_DIR", self.spec["cudnn"].prefix.include)
679681
env.set("CUDNN_LIBRARY", self.spec["cudnn"].libs[0])
680682

683+
enable_or_disable("cusparselt")
681684
enable_or_disable("fbgemm")
682685
enable_or_disable("kineto")
683686
enable_or_disable("magma")

0 commit comments

Comments
 (0)