
Commit d37d09e

Add a prototype for RL-based subspace diagonalization.

This PR adds a prototype for reinforcement-learning-based subspace diagonalization. It mainly comprises:

1. A neural network named crossmlp.
2. A `single_relative` operator and its CUDA kernel.
3. An interface named `rldiag`.

The `crossmlp` network is simply an MLP applied to each configuration plus a layer that couples different configurations, as discussed in #16. The `single_relative` operator samples a single relative configuration for each configuration in the pool; the existing `find_relative` is not used here because we want exactly one relative configuration per input. Note that the result of this function is not guaranteed to be unique across inputs. The `rldiag` interface implements the algorithm proposed in #13.

PR tracking at: USTC-KnowledgeComputingLab/qmb#14
2 parents 371c83d + 4346d91 commit d37d09e
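
For orientation, the semantics of the new `single_relative` operator can be modeled in a few lines of pure Python: for every configuration in the pool, apply each Hamiltonian term, discard results that are already in the pool, and keep exactly one survivor sampled with probability proportional to the magnitude of the term's coefficient. This is only an illustrative sketch, not the CUDA implementation; `apply_term` is a hypothetical helper standing in for `hamiltonian_apply_kernel`.

```python
import random

def single_relative_reference(pool, terms, apply_term):
    """Illustrative model of single_relative (a hedged sketch, not the real kernel).

    pool: list of hashable configurations.
    terms: list of (site, kind, coef) Hamiltonian terms, coef a complex number.
    apply_term: hypothetical helper returning the configuration produced by one
        term, or None when the term annihilates the input configuration.
    """
    exclude = set(pool)  # the kernel excludes configurations already in the pool
    result = []
    for config in pool:
        candidates, weights = [], []
        for site, kind, coef in terms:
            new_config = apply_term(config, site, kind)
            if new_config is None or new_config in exclude:
                continue
            candidates.append(new_config)
            weights.append(abs(coef))  # weight = |coefficient| of the term
        # Exactly one relative configuration per pool entry; uniqueness across
        # entries is NOT guaranteed, matching the note in the commit message.
        if candidates:
            result.append(random.choices(candidates, weights=weights, k=1)[0])
        else:
            result.append(None)  # the kernel leaves a zero-initialized entry here
    return result
```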

File tree

10 files changed (+753 lines, -0 lines)

qmb/__main__.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 from . import ising as _ # type: ignore[no-redef]
 from . import vmc as _ # type: ignore[no-redef]
 from . import imag as _ # type: ignore[no-redef]
+from . import rldiag as _ # type: ignore[no-redef]
 from . import precompile as _ # type: ignore[no-redef]
 from . import list_loss as _ # type: ignore[no-redef]
 from . import chop_imag as _ # type: ignore[no-redef]
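
Importing the module for its side effects mirrors how the other subcommands (vmc, imag, and so on) are registered, so the new interface should presumably be reachable as something like `python -m qmb rldiag ...` (a hypothetical invocation; the actual arguments live in the rldiag module added by this PR).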

qmb/_hamiltonian.cpp

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ TORCH_LIBRARY_FRAGMENT(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), m) {
     m.def("apply_within(Tensor configs_i, Tensor psi_i, Tensor configs_j, Tensor site, Tensor kind, Tensor coef) -> Tensor");
     m.def("find_relative(Tensor configs_i, Tensor psi_i, int count_selected, Tensor site, Tensor kind, Tensor coef, Tensor configs_exclude) -> Tensor");
+    m.def("single_relative(Tensor configs, Tensor site, Tensor kind, Tensor coef) -> Tensor");
 }
 #undef QMB_LIBRARY
 #undef QMB_LIBRARY_HELPER
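
Once this schema is registered, the operator is callable from Python through the usual `torch.ops` mechanism. A hedged sketch of such a call follows; the namespace string is a placeholder, since the real library name is generated by the QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT) macro, and the shapes merely follow the TORCH_CHECK constraints in the CUDA implementation below.

```python
import torch

# Hypothetical call site: "qmb_library" is a placeholder namespace; the real
# name is produced by the QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT) macro.
n_qubytes, max_op_number, term_number, batch_size = 16, 4, 100, 1024

configs = torch.randint(0, 256, (batch_size, n_qubytes), dtype=torch.uint8, device="cuda")
site = torch.zeros((term_number, max_op_number), dtype=torch.int16, device="cuda")
kind = torch.zeros((term_number, max_op_number), dtype=torch.uint8, device="cuda")
coef = torch.randn((term_number, 2), dtype=torch.float64, device="cuda")  # (re, im) per term

# One sampled relative configuration per input configuration, same shape as configs.
result = torch.ops.qmb_library.single_relative(configs, site, kind, coef)
```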

qmb/_hamiltonian_cuda.cu

Lines changed: 194 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include <ATen/cuda/Exceptions.h>
 #include <c10/cuda/CUDAStream.h>
 #include <cuda_runtime.h>
+#include <curand_kernel.h>
 #include <thrust/sort.h>
 #include <torch/extension.h>

@@ -670,6 +671,198 @@ auto find_relative_interface(
     return unique_nonzero_result_config;
 }
 
+template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+__device__ void single_relative_kernel(
+    std::int64_t term_index,
+    std::int64_t batch_index,
+    std::int64_t term_number,
+    std::int64_t batch_size,
+    std::int64_t exclude_size,
+    std::uint64_t seed,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind, // term_number
+    const std::array<double, 2>* coef, // term_number
+    const std::array<std::uint8_t, n_qubytes>* configs, // batch_size
+    const std::array<std::uint8_t, n_qubytes>* exclude_configs, // exclude_size
+    std::array<std::uint8_t, n_qubytes>* result_configs, // batch_size
+    double* score, // batch_size
+    int* mutex // batch_size
+) {
+    std::array<std::uint8_t, n_qubytes> current_configs = configs[batch_index];
+    auto [success, parity] = hamiltonian_apply_kernel<max_op_number, n_qubytes, particle_cut>(
+        /*current_configs=*/current_configs,
+        /*term_index=*/term_index,
+        /*batch_index=*/batch_index,
+        /*site=*/site,
+        /*kind=*/kind
+    );
+
+    if (!success) {
+        return;
+    }
+    success = true;
+    std::int64_t low = 0;
+    std::int64_t high = exclude_size - 1;
+    std::int64_t mid = 0;
+    auto compare = array_less<std::uint8_t, n_qubytes>();
+    while (low <= high) {
+        mid = (low + high) / 2;
+        if (compare(current_configs, exclude_configs[mid])) {
+            high = mid - 1;
+        } else if (compare(exclude_configs[mid], current_configs)) {
+            low = mid + 1;
+        } else {
+            success = false;
+            break;
+        }
+    }
+    if (!success) {
+        return;
+    }
+
+    // Efraimidis-Spirakis Algorithm is used here.
+    auto weight = std::pow(coef[term_index][0] * coef[term_index][0] + coef[term_index][1] * coef[term_index][1], 0.5);
+    curandState state;
+    curand_init(seed, term_index, 0, &state);
+    auto key = std::pow(curand_uniform_double(&state), 1.0 / weight);
+    if (score[batch_index] < key) {
+        mutex_lock(&mutex[batch_index]);
+        if (score[batch_index] < key) {
+            score[batch_index] = key;
+            result_configs[batch_index] = current_configs;
+        }
+        mutex_unlock(&mutex[batch_index]);
+    }
+}
+
+template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+__global__ void single_relative_kernel_interface(
+    std::int64_t term_number,
+    std::int64_t batch_size,
+    std::int64_t exclude_size,
+    std::uint64_t seed,
+    const std::array<std::int16_t, max_op_number>* site, // term_number
+    const std::array<std::uint8_t, max_op_number>* kind, // term_number
+    const std::array<double, 2>* coef, // term_number
+    const std::array<std::uint8_t, n_qubytes>* configs, // batch_size
+    const std::array<std::uint8_t, n_qubytes>* exclude_configs, // exclude_size
+    std::array<std::uint8_t, n_qubytes>* result_configs, // batch_size
+    double* score, // batch_size
+    int* mutex // batch_size
+) {
+    std::int64_t term_index = blockIdx.x * blockDim.x + threadIdx.x;
+    std::int64_t batch_index = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (term_index < term_number && batch_index < batch_size) {
+        single_relative_kernel<max_op_number, n_qubytes, particle_cut>(
+            /*term_index=*/term_index,
+            /*batch_index=*/batch_index,
+            /*term_number=*/term_number,
+            /*batch_size=*/batch_size,
+            /*exclude_size=*/exclude_size,
+            /*seed=*/seed,
+            /*site=*/site,
+            /*kind=*/kind,
+            /*coef=*/coef,
+            /*configs=*/configs,
+            /*exclude_configs=*/exclude_configs,
+            /*result_configs=*/result_configs,
+            /*score=*/score,
+            /*mutex=*/mutex
+        );
+    }
+}
+
+template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+auto single_relative_interface(const torch::Tensor& configs, const torch::Tensor& site, const torch::Tensor& kind, const torch::Tensor& coef)
+    -> torch::Tensor {
+    std::int64_t device_id = configs.device().index();
+    std::int64_t batch_size = configs.size(0);
+    std::int64_t term_number = site.size(0);
+
+    TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.");
+    TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
+    TORCH_CHECK(configs.is_contiguous(), "configs must be contiguous.");
+    TORCH_CHECK(configs.dtype() == torch::kUInt8, "configs must be uint8.");
+    TORCH_CHECK(configs.dim() == 2, "configs must be 2D.");
+    TORCH_CHECK(configs.size(0) == batch_size, "configs batch size must match the provided batch_size.");
+    TORCH_CHECK(configs.size(1) == n_qubytes, "configs must have the same number of qubytes as the provided n_qubytes.");
+
+    TORCH_CHECK(site.device().type() == torch::kCUDA, "site must be on CUDA.");
+    TORCH_CHECK(site.device().index() == device_id, "site must be on the same device as others.");
+    TORCH_CHECK(site.is_contiguous(), "site must be contiguous.");
+    TORCH_CHECK(site.dtype() == torch::kInt16, "site must be int16.");
+    TORCH_CHECK(site.dim() == 2, "site must be 2D.");
+    TORCH_CHECK(site.size(0) == term_number, "site size must match the provided term_number.");
+    TORCH_CHECK(site.size(1) == max_op_number, "site must match the provided max_op_number.");
+
+    TORCH_CHECK(kind.device().type() == torch::kCUDA, "kind must be on CUDA.");
+    TORCH_CHECK(kind.device().index() == device_id, "kind must be on the same device as others.");
+    TORCH_CHECK(kind.is_contiguous(), "kind must be contiguous.");
+    TORCH_CHECK(kind.dtype() == torch::kUInt8, "kind must be uint8.");
+    TORCH_CHECK(kind.dim() == 2, "kind must be 2D.");
+    TORCH_CHECK(kind.size(0) == term_number, "kind size must match the provided term_number.");
+    TORCH_CHECK(kind.size(1) == max_op_number, "kind must match the provided max_op_number.");
+
+    TORCH_CHECK(coef.device().type() == torch::kCUDA, "coef must be on CUDA.");
+    TORCH_CHECK(coef.device().index() == device_id, "coef must be on the same device as others.");
+    TORCH_CHECK(coef.is_contiguous(), "coef must be contiguous.");
+    TORCH_CHECK(coef.dtype() == torch::kFloat64, "coef must be float64.");
+    TORCH_CHECK(coef.dim() == 2, "coef must be 2D.");
+    TORCH_CHECK(coef.size(0) == term_number, "coef size must match the provided term_number.");
+    TORCH_CHECK(coef.size(1) == 2, "coef must contain 2 elements for each term.");
+
+    auto stream = at::cuda::getCurrentCUDAStream(device_id);
+    auto policy = thrust::device.on(stream);
+
+    cudaDeviceProp prop;
+    AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));
+    std::int64_t max_threads_per_block = prop.maxThreadsPerBlock;
+
+    auto sorted_configs = configs.clone(torch::MemoryFormat::Contiguous);
+
+    thrust::sort(
+        policy,
+        reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(sorted_configs.data_ptr()),
+        reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(sorted_configs.data_ptr()) + batch_size,
+        array_less<std::uint8_t, n_qubytes>()
+    );
+
+    auto seed_tensor = torch::randint(int64_t(0), std::numeric_limits<std::int64_t>::max(), {}, torch::TensorOptions().dtype(torch::kInt64));
+    auto seed = *(int64_t*)(seed_tensor.data_ptr());
+
+    auto result_configs = torch::zeros({batch_size, n_qubytes}, torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA, device_id));
+    auto score = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCUDA, device_id))
+                     .fill_(-std::numeric_limits<double>::infinity());
+    int* mutex;
+    AT_CUDA_CHECK(cudaMalloc(&mutex, sizeof(int) * batch_size));
+    AT_CUDA_CHECK(cudaMemset(mutex, 0, sizeof(int) * batch_size));
+
+    auto threads_per_block = dim3{1, max_threads_per_block >> 1}; // I don't know why, but need to divide by 2 to avoid errors
+    auto num_blocks =
+        dim3{(term_number + threads_per_block.x - 1) / threads_per_block.x, (batch_size + threads_per_block.y - 1) / threads_per_block.y};
+
+    single_relative_kernel_interface<max_op_number, n_qubytes, particle_cut><<<num_blocks, threads_per_block, 0, stream>>>(
+        /*term_number=*/term_number,
+        /*batch_size=*/batch_size,
+        /*exclude_size=*/batch_size,
+        /*seed=*/seed,
+        /*site=*/reinterpret_cast<const std::array<std::int16_t, max_op_number>*>(site.data_ptr()),
+        /*kind=*/reinterpret_cast<const std::array<std::uint8_t, max_op_number>*>(kind.data_ptr()),
+        /*coef=*/reinterpret_cast<const std::array<double, 2>*>(coef.data_ptr()),
+        /*configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(configs.data_ptr()),
+        /*exclude_configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(sorted_configs.data_ptr()),
+        /*result_configs=*/reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(result_configs.data_ptr()),
+        /*score=*/reinterpret_cast<double*>(score.data_ptr()),
+        /*mutex=*/mutex
+    );
+    AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    AT_CUDA_CHECK(cudaFree(mutex));
+
+    return result_configs;
+}
+
 #ifndef N_QUBYTES
 #define N_QUBYTES 0
 #endif
@@ -683,6 +876,7 @@ auto find_relative_interface(
 TORCH_LIBRARY_IMPL(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), CUDA, m) {
     m.impl("apply_within", apply_within_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
     m.impl("find_relative", find_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
+    m.impl("single_relative", single_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
 }
 #undef QMB_LIBRARY
 #undef QMB_LIBRARY_HELPER
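
The comment in `single_relative_kernel` names the Efraimidis-Spirakis algorithm, and it is worth spelling out why it suits a massively parallel kernel: each candidate i independently draws a key u_i^(1/w_i) with u_i ~ Uniform(0, 1) and weight w_i, and the candidate with the largest key wins; the winner is distributed with probability proportional to w_i. Because each (term, configuration) thread only needs to compute its own key and then race on a per-configuration maximum (the double-checked `score`/`mutex` update above), no global coordination per configuration is needed. A minimal CPU-side sketch of the same selection rule:

```python
import random

def weighted_pick_one(items, weights):
    """Efraimidis-Spirakis selection of a single item.

    Each item draws key = u ** (1 / w) and the largest key wins, which is
    equivalent to picking item i with probability weights[i] / sum(weights).
    The per-item keys are independent, so in the CUDA kernel every thread
    computes its own key and only the running maximum needs synchronization.
    """
    best_item, best_key = None, float("-inf")
    for item, weight in zip(items, weights):
        key = random.random() ** (1.0 / weight)
        if key > best_key:
            best_item, best_key = item, key
    return best_item
```

For example, `weighted_pick_one(["a", "b", "c"], [1.0, 2.0, 3.0])` returns "c" about half the time, matching weight 3 out of a total of 6.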

qmb/crossmlp.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@ (new file)
"""
This file implements a cross MLP network.
"""

import typing
import torch
from .bitspack import unpack_int


class FakeLinear(torch.nn.Module):
    """
    A fake linear layer with zero input dimension to avoid PyTorch initialization warnings.
    """

    def __init__(self, dim_in: int, dim_out: int) -> None:
        super().__init__()
        assert dim_in == 0
        self.bias: torch.nn.Parameter = torch.nn.Parameter(torch.zeros([dim_out]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the fake linear layer.
        """
        batch, _ = x.shape
        return self.bias.view([1, -1]).expand([batch, -1])


def select_linear_layer(dim_in: int, dim_out: int) -> torch.nn.Module:
    """
    Selects between a fake linear layer and a standard one to avoid initialization warnings when dim_in is zero.
    """
    if dim_in == 0:  # pylint: disable=no-else-return
        return FakeLinear(dim_in, dim_out)
    else:
        return torch.nn.Linear(dim_in, dim_out)


class MLP(torch.nn.Module):
    """
    This module implements a multi-layer MLP with the given dim_input, dim_output and hidden_size.
    """

    def __init__(self, dim_input: int, dim_output: int, hidden_size: tuple[int, ...]) -> None:
        super().__init__()
        self.dim_input: int = dim_input
        self.dim_output: int = dim_output
        self.hidden_size: tuple[int, ...] = hidden_size
        self.depth: int = len(hidden_size)

        dimensions: list[int] = [dim_input] + list(hidden_size) + [dim_output]
        linears: list[torch.nn.Module] = [select_linear_layer(i, j) for i, j in zip(dimensions[:-1], dimensions[1:])]
        modules: list[torch.nn.Module] = [module for linear in linears for module in (linear, torch.nn.SiLU())][:-1]
        self.model: torch.nn.Module = torch.nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MLP.
        """
        return self.model(x)


class WaveFunction(torch.nn.Module):
    """
    The wave function for the cross MLP network.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        sites: int,  # Number of qubits
        physical_dim: int,  # Dimension of the physical space, which is always 2 for this network
        is_complex: bool,  # Indicates whether the wave function is complex-valued, which is always false for this network
        embedding_hidden_size: tuple[int, ...],  # Hidden layer sizes for the embedding part
        embedding_size: int,  # The dimension of the embedding
        momentum_hidden_size: tuple[int, ...],  # Hidden layer sizes for the momentum part
        momentum_count: int,  # The number of momentum layers
        tail_hidden_size: tuple[int, ...],  # Hidden layer sizes for the tail part
        kind: typing.Literal[0, 1, 2],  # Kind of the crossmlp forward function
        ordering: int | list[int],  # Ordering of sites: +1 for normal order, -1 for reversed order, or a custom order list
    ) -> None:
        super().__init__()
        self.sites: int = sites
        assert physical_dim == 2
        # This module is only used in reinforcement learning, which expects real values for the weights.
        assert is_complex == False  # pylint: disable=singleton-comparison
        self.embedding_hidden_size: tuple[int, ...] = embedding_hidden_size
        self.embedding_size: int = embedding_size
        self.momentum_hidden_size: tuple[int, ...] = momentum_hidden_size
        self.momentum_count: int = momentum_count
        self.tail_hidden_size: tuple[int, ...] = tail_hidden_size
        self.kind: typing.Literal[0, 1, 2] = kind

        self.emb = MLP(self.sites, self.embedding_size, self.embedding_hidden_size)
        self.momentum = torch.nn.ModuleList([MLP(self.embedding_size, self.embedding_size, momentum_hidden_size) for _ in range(self.momentum_count)])
        self.tail = MLP(self.embedding_size, 1, tail_hidden_size)

        # Site Ordering Configuration
        # +1 for normal order, -1 for reversed order
        if isinstance(ordering, int) and ordering == +1:
            ordering = list(range(self.sites))
        if isinstance(ordering, int) and ordering == -1:
            ordering = list(reversed(range(self.sites)))
        self.ordering: torch.Tensor
        self.register_buffer('ordering', torch.tensor(ordering, dtype=torch.int64))
        self.ordering_reversed: torch.Tensor
        self.register_buffer('ordering_reversed', torch.scatter(torch.zeros(self.sites, dtype=torch.int64), 0, self.ordering, torch.arange(self.sites, dtype=torch.int64)))

        # Dummy Parameter for Device and Dtype Retrieval
        # This parameter is used to infer the device and dtype of the model.
        self.dummy_param = torch.nn.Parameter(torch.empty(0))

    @torch.jit.export
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the wave function psi for the given configurations.
        """
        dtype = self.dummy_param.dtype
        # x: batch_size * sites
        x = unpack_int(x, size=1, last_dim=self.sites)
        # Apply ordering
        x = torch.index_select(x, 1, self.ordering_reversed)
        # Dtype conversion
        x = x.to(dtype=dtype)

        # emb: batch_size * embedding_size
        emb = self.emb(x)

        if self.kind == 0:
            # x' = F(x - E[x]) + x
            for layer in self.momentum:
                new_emb = emb - emb.mean(dim=0, keepdim=True)
                new_emb = layer(new_emb)
                emb = emb + new_emb
                emb = emb / emb.norm(p=2, dim=1, keepdim=True)
        elif self.kind == 1:
            # x' = F(x) - E[F(x)] + x
            for layer in self.momentum:
                new_emb = layer(emb)
                new_emb = new_emb - new_emb.mean(dim=0, keepdim=True)
                emb = emb + new_emb
                emb = emb / emb.norm(p=2, dim=1, keepdim=True)
        elif self.kind == 2:
            # x' = (F(x) + x) - E[F(x) + x]
            for layer in self.momentum:
                new_emb = layer(emb)
                new_emb = new_emb + emb
                emb = new_emb - new_emb.mean(dim=0, keepdim=True)
                emb = emb / emb.norm(p=2, dim=1, keepdim=True)
        else:
            raise ValueError(f"Invalid kind: {self.kind}")

        tail = self.tail(emb).squeeze(-1)
        return tail

    @torch.jit.export
    def generate_unique(self, batch_size: int, block_num: int = 1) -> tuple[torch.Tensor, torch.Tensor, None, None]:
        """
        This module does not support generating unique configurations.
        """
        # This module is only used in reinforcement learning, which does not require configuration sampling.
        raise NotImplementedError("The generate_unique method is not implemented for this class.")
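
To make the cross-configuration coupling concrete: every `mean(dim=0, keepdim=True)` above mixes information across the batch, so the weight assigned to one configuration depends on which other configurations are evaluated with it. A hedged usage sketch follows; the hyperparameter values are invented for illustration, and the packed-input shape assumes `unpack_int(x, size=1, last_dim=sites)` expands one bit per site.

```python
import torch
from qmb.crossmlp import WaveFunction

# Hypothetical hyperparameters; only the keyword names come from the constructor.
net = WaveFunction(
    sites=16,
    physical_dim=2,
    is_complex=False,
    embedding_hidden_size=(64, 64),
    embedding_size=32,
    momentum_hidden_size=(64,),
    momentum_count=3,
    tail_hidden_size=(64,),
    kind=1,
    ordering=+1,
)

# 16 sites at 1 bit each -> 2 packed uint8 bytes per configuration (assumed).
configs = torch.randint(0, 256, (128, 2), dtype=torch.uint8)
weights = net(configs)  # shape (128,): one real-valued weight per configuration

# Note: because of the batch-wise mean, net(configs[:64]) is generally NOT
# equal to weights[:64]; the pool is scored jointly, not configuration by
# configuration.
```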
