
Commit 49cb1b5

hzhangxyz and windy-pig committed
Add support for conditional compilation of cpu and cuda kernels.
Co-authored-by: Junkai Xu <[email protected]>
1 parent 4f01926 commit 49cb1b5
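
In outline: the cache of JIT-built extensions is now keyed by (device_type, n_qubytes, particle_cut) rather than by n_qubytes alone, and each build compiles only the source file that matches the requested device, so a CPU-only run never compiles _hamiltonian_cuda.cu. The default ("declaration", 0, 0) build is still returned as a Python module; every other build is loaded only for the operators it registers, which callers then fetch from torch.ops. A rough, hedged sketch of the two torch.utils.cpp_extension.load modes involved (placeholder source files, not the project's):

    # Sketch only: "example_decl.cpp" and "example_cpu.cpp" are hypothetical
    # sources used to illustrate the two loading modes.
    import torch
    import torch.utils.cpp_extension

    # is_python_module=True compiles the sources and returns an importable
    # Python extension module.
    decl = torch.utils.cpp_extension.load(
        name="example_decl",
        sources=["example_decl.cpp"],
        is_python_module=True,
    )

    # is_python_module=False only loads the compiled shared library into the
    # process; any operators it registers become visible under
    # torch.ops.<namespace> (in this commit the namespace matches the
    # extension name).
    torch.utils.cpp_extension.load(
        name="example_cpu",
        sources=["example_cpu.cpp"],
        is_python_module=False,
    )
    ops = getattr(torch.ops, "example_cpu")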

1 file changed: +27 −18

qmb/hamiltonian.py

Lines changed: 27 additions & 18 deletions
@@ -14,31 +14,40 @@ class Hamiltonian:
     The Hamiltonian type, which stores the Hamiltonian and processes iteration over each term in the Hamiltonian for given configurations.
     """

-    _hamiltonian_module: dict[int, object] = {}
+    _hamiltonian_module: dict[tuple[str, int, int], object] = {}

     @classmethod
-    def _load_module(cls, n_qubytes: int = 0, particle_cut: int = 0) -> object:
-        if n_qubytes not in cls._hamiltonian_module:
-            name = "qmb_hamiltonian" if n_qubytes == 0 else f"qmb_hamiltonian_{n_qubytes}_{particle_cut}"
-            build_directory = platformdirs.user_cache_path("qmb", "kclab") / name
+    def _load_module(cls, device_type: str = "declaration", n_qubytes: int = 0, particle_cut: int = 0) -> object:
+        if device_type != "declaration":
+            cls._load_module("declaration", n_qubytes, particle_cut)  # Ensure the declaration module is loaded first
+        key = (device_type, n_qubytes, particle_cut)
+        is_prepare = key == ("declaration", 0, 0)
+        name = "qmb_hamiltonian" if is_prepare else f"qmb_hamiltonian_{n_qubytes}_{particle_cut}"
+        if key not in cls._hamiltonian_module:
+            build_directory = platformdirs.user_cache_path("qmb", "kclab") / name / device_type
             build_directory.mkdir(parents=True, exist_ok=True)
             folder = os.path.dirname(__file__)
-            cls._hamiltonian_module[n_qubytes] = torch.utils.cpp_extension.load(
+            match device_type:
+                case "declaration":
+                    sources = [f"{folder}/_hamiltonian.cpp"]
+                case "cpu":
+                    sources = [f"{folder}/_hamiltonian_cpu.cpp"]
+                case "cuda":
+                    sources = [f"{folder}/_hamiltonian_cuda.cu"]
+                case _:
+                    raise ValueError("Unsupported device type")
+            cls._hamiltonian_module[key] = torch.utils.cpp_extension.load(
                 name=name,
-                sources=[
-                    f"{folder}/_hamiltonian.cpp",
-                    f"{folder}/_hamiltonian_cpu.cpp",
-                    f"{folder}/_hamiltonian_cuda.cu",
-                ],
-                is_python_module=n_qubytes == 0,
+                sources=sources,
+                is_python_module=is_prepare,
                 extra_cflags=["-O3", "-ffast-math", "-march=native", f"-DN_QUBYTES={n_qubytes}", f"-DPARTICLE_CUT={particle_cut}", "-std=c++20"],
                 extra_cuda_cflags=["-O3", "--use_fast_math", f"-DN_QUBYTES={n_qubytes}", f"-DPARTICLE_CUT={particle_cut}", "-std=c++20"],
                 build_directory=build_directory,
             )
-        if n_qubytes == 0:  # pylint: disable=no-else-return
-            return cls._hamiltonian_module[n_qubytes]
+        if is_prepare:  # pylint: disable=no-else-return
+            return cls._hamiltonian_module[key]
         else:
-            return getattr(torch.ops, f"qmb_hamiltonian_{n_qubytes}_{particle_cut}")
+            return getattr(torch.ops, name)

     @classmethod
     def _prepare(cls, hamiltonian: dict[tuple[tuple[int, int], ...], complex]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

@@ -98,7 +107,7 @@ def apply_within(
         """
         self._prepare_data(configs_i.device)
         _apply_within: typing.Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
-        _apply_within = getattr(self._load_module(configs_i.size(1), self.particle_cut), "apply_within")
+        _apply_within = getattr(self._load_module(configs_i.device.type, configs_i.size(1), self.particle_cut), "apply_within")
         psi_j = torch.view_as_complex(_apply_within(configs_i, torch.view_as_real(psi_i), configs_j, self.site, self.kind, self.coef))
         return psi_j

@@ -133,7 +142,7 @@ def find_relative(
         configs_exclude = configs_i
         self._prepare_data(configs_i.device)
         _find_relative: typing.Callable[[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
-        _find_relative = getattr(self._load_module(configs_i.size(1), self.particle_cut), "find_relative")
+        _find_relative = getattr(self._load_module(configs_i.device.type, configs_i.size(1), self.particle_cut), "find_relative")
         configs_j = _find_relative(configs_i, torch.view_as_real(psi_i), count_selected, self.site, self.kind, self.coef, configs_exclude)
         return configs_j

@@ -153,6 +162,6 @@ def single_relative(self, configs: torch.Tensor) -> torch.Tensor:
         """
         self._prepare_data(configs.device)
         _single_relative: typing.Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
-        _single_relative = getattr(self._load_module(configs.size(1), self.particle_cut), "single_relative")
+        _single_relative = getattr(self._load_module(configs.device.type, configs.size(1), self.particle_cut), "single_relative")
         configs_result = _single_relative(configs, self.site, self.kind, self.coef)
         return configs_result
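
Call sites now pass configs.device.type as the first argument, so the backend to build is inferred from where the configuration tensor lives. A hypothetical usage, not part of the commit (the tensor shape, dtype, and the hamiltonian/psi_i/configs_j names are illustrative), showing which extensions get compiled:

    import torch

    # Configurations on the CPU: only the declaration and "cpu" builds are
    # JIT-compiled; _hamiltonian_cuda.cu is never touched.
    configs = torch.zeros(128, 4, dtype=torch.uint8)  # device.type == "cpu", n_qubytes == 4
    # hamiltonian.apply_within(configs, psi_i, configs_j)
    #   -> Hamiltonian._load_module("cpu", 4, particle_cut)
    #      which first calls _load_module("declaration", 4, particle_cut)

    # The same call on CUDA tensors compiles the "cuda" variant instead.
    if torch.cuda.is_available():
        configs = configs.cuda()  # device.type == "cuda"
    #   -> Hamiltonian._load_module("cuda", 4, particle_cut)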
