@@ -14,31 +14,40 @@ class Hamiltonian:
14
14
The Hamiltonian type, which stores the Hamiltonian and processes iteration over each term in the Hamiltonian for given configurations.
15
15
"""
16
16
17
- _hamiltonian_module : dict [int , object ] = {}
17
+ _hamiltonian_module : dict [tuple [ str , int , int ] , object ] = {}
18
18
19
19
@classmethod
20
- def _load_module (cls , n_qubytes : int = 0 , particle_cut : int = 0 ) -> object :
21
- if n_qubytes not in cls ._hamiltonian_module :
22
- name = "qmb_hamiltonian" if n_qubytes == 0 else f"qmb_hamiltonian_{ n_qubytes } _{ particle_cut } "
23
- build_directory = platformdirs .user_cache_path ("qmb" , "kclab" ) / name
20
+ def _load_module (cls , device_type : str = "declaration" , n_qubytes : int = 0 , particle_cut : int = 0 ) -> object :
21
+ if device_type != "declaration" :
22
+ cls ._load_module ("declaration" , n_qubytes , particle_cut ) # Ensure the declaration module is loaded first
23
+ key = (device_type , n_qubytes , particle_cut )
24
+ is_prepare = key == ("declaration" , 0 , 0 )
25
+ name = "qmb_hamiltonian" if is_prepare else f"qmb_hamiltonian_{ n_qubytes } _{ particle_cut } "
26
+ if key not in cls ._hamiltonian_module :
27
+ build_directory = platformdirs .user_cache_path ("qmb" , "kclab" ) / name / device_type
24
28
build_directory .mkdir (parents = True , exist_ok = True )
25
29
folder = os .path .dirname (__file__ )
26
- cls ._hamiltonian_module [n_qubytes ] = torch .utils .cpp_extension .load (
30
+ match device_type :
31
+ case "declaration" :
32
+ sources = [f"{ folder } /_hamiltonian.cpp" ]
33
+ case "cpu" :
34
+ sources = [f"{ folder } /_hamiltonian_cpu.cpp" ]
35
+ case "cuda" :
36
+ sources = [f"{ folder } /_hamiltonian_cuda.cu" ]
37
+ case _:
38
+ raise ValueError ("Unsupported device type" )
39
+ cls ._hamiltonian_module [key ] = torch .utils .cpp_extension .load (
27
40
name = name ,
28
- sources = [
29
- f"{ folder } /_hamiltonian.cpp" ,
30
- f"{ folder } /_hamiltonian_cpu.cpp" ,
31
- f"{ folder } /_hamiltonian_cuda.cu" ,
32
- ],
33
- is_python_module = n_qubytes == 0 ,
41
+ sources = sources ,
42
+ is_python_module = is_prepare ,
34
43
extra_cflags = ["-O3" , "-ffast-math" , "-march=native" , f"-DN_QUBYTES={ n_qubytes } " , f"-DPARTICLE_CUT={ particle_cut } " , "-std=c++20" ],
35
44
extra_cuda_cflags = ["-O3" , "--use_fast_math" , f"-DN_QUBYTES={ n_qubytes } " , f"-DPARTICLE_CUT={ particle_cut } " , "-std=c++20" ],
36
45
build_directory = build_directory ,
37
46
)
38
- if n_qubytes == 0 : # pylint: disable=no-else-return
39
- return cls ._hamiltonian_module [n_qubytes ]
47
+ if is_prepare : # pylint: disable=no-else-return
48
+ return cls ._hamiltonian_module [key ]
40
49
else :
41
- return getattr (torch .ops , f"qmb_hamiltonian_ { n_qubytes } _ { particle_cut } " )
50
+ return getattr (torch .ops , name )
42
51
43
52
@classmethod
44
53
def _prepare (cls , hamiltonian : dict [tuple [tuple [int , int ], ...], complex ]) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
@@ -98,7 +107,7 @@ def apply_within(
98
107
"""
99
108
self ._prepare_data (configs_i .device )
100
109
_apply_within : typing .Callable [[torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ], torch .Tensor ]
101
- _apply_within = getattr (self ._load_module (configs_i .size (1 ), self .particle_cut ), "apply_within" )
110
+ _apply_within = getattr (self ._load_module (configs_i .device . type , configs_i . size (1 ), self .particle_cut ), "apply_within" )
102
111
psi_j = torch .view_as_complex (_apply_within (configs_i , torch .view_as_real (psi_i ), configs_j , self .site , self .kind , self .coef ))
103
112
return psi_j
104
113
@@ -133,7 +142,7 @@ def find_relative(
133
142
configs_exclude = configs_i
134
143
self ._prepare_data (configs_i .device )
135
144
_find_relative : typing .Callable [[torch .Tensor , torch .Tensor , int , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ], torch .Tensor ]
136
- _find_relative = getattr (self ._load_module (configs_i .size (1 ), self .particle_cut ), "find_relative" )
145
+ _find_relative = getattr (self ._load_module (configs_i .device . type , configs_i . size (1 ), self .particle_cut ), "find_relative" )
137
146
configs_j = _find_relative (configs_i , torch .view_as_real (psi_i ), count_selected , self .site , self .kind , self .coef , configs_exclude )
138
147
return configs_j
139
148
@@ -153,6 +162,6 @@ def single_relative(self, configs: torch.Tensor) -> torch.Tensor:
153
162
"""
154
163
self ._prepare_data (configs .device )
155
164
_single_relative : typing .Callable [[torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ], torch .Tensor ]
156
- _single_relative = getattr (self ._load_module (configs .size (1 ), self .particle_cut ), "single_relative" )
165
+ _single_relative = getattr (self ._load_module (configs .device . type , configs . size (1 ), self .particle_cut ), "single_relative" )
157
166
configs_result = _single_relative (configs , self .site , self .kind , self .coef )
158
167
return configs_result
0 commit comments