Skip to content

Commit 467b73d

Browse files
committed
Set pytorch cuda arch list explicitly to get rid of warning.
1 parent 5b23052 commit 467b73d

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

qmb/hamiltonian.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,27 @@ class Hamiltonian:
1616

1717
_hamiltonian_module: dict[tuple[str, int, int], object] = {}
1818

19+
@classmethod
20+
def _set_torch_cuda_arch_list(cls) -> None:
21+
"""
22+
Set the CUDA architecture list for PyTorch to use when compiling the PyTorch extensions.
23+
"""
24+
if not torch.cuda.is_available():
25+
return
26+
if "TORCH_CUDA_ARCH_LIST" in os.environ:
27+
return
28+
supported_sm = [int(arch[3:]) for arch in torch.cuda.get_arch_list() if arch.startswith("sm_")]
29+
max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
30+
arch_list = set()
31+
for i in range(torch.cuda.device_count()):
32+
capability = min(max_supported_sm, torch.cuda.get_device_capability(i))
33+
arch = f'{capability[0]}.{capability[1]}'
34+
arch_list.add(arch)
35+
os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(sorted(arch_list))
36+
1937
@classmethod
2038
def _load_module(cls, device_type: str = "declaration", n_qubytes: int = 0, particle_cut: int = 0) -> object:
39+
cls._set_torch_cuda_arch_list()
2140
if device_type != "declaration":
2241
cls._load_module("declaration", n_qubytes, particle_cut) # Ensure the declaration module is loaded first
2342
key = (device_type, n_qubytes, particle_cut)

0 commit comments

Comments
 (0)