Skip to content

Commit 522af85

Browse files
committed
Remove thrust from dependencies of the cpu kernel.
PR: USTC-KnowledgeComputingLab/qmb#46 Signed-off-by: Hao Zhang <[email protected]>
2 parents 93f162d + ccd20af commit 522af85

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

qmb/_hamiltonian_cpu.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#include <thrust/execution_policy.h>
2-
#include <thrust/sort.h>
31
#include <torch/extension.h>
42

53
namespace qmb_hamiltonian_cpu {
@@ -225,17 +223,20 @@ auto apply_within_interface(
225223
TORCH_CHECK(coef.size(0) == term_number, "coef size must match the provided term_number.");
226224
TORCH_CHECK(coef.size(1) == 2, "coef must contain 2 elements for each term.");
227225

228-
auto sorted_result_configs = result_configs.clone(torch::MemoryFormat::Contiguous);
229226
auto result_sort_index = torch::arange(result_batch_size, torch::TensorOptions().dtype(torch::kInt64).device(device, device_id));
230-
auto sorted_result_psi = torch::zeros({result_batch_size, 2}, torch::TensorOptions().dtype(torch::kFloat64).device(device, device_id));
231227

232-
thrust::sort_by_key(
233-
thrust::host,
234-
reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(sorted_result_configs.data_ptr()),
235-
reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(sorted_result_configs.data_ptr()) + result_batch_size,
228+
std::sort(
236229
reinterpret_cast<std::int64_t*>(result_sort_index.data_ptr()),
237-
array_less<std::uint8_t, n_qubytes>()
230+
reinterpret_cast<std::int64_t*>(result_sort_index.data_ptr()) + result_batch_size,
231+
[&result_configs](std::int64_t i1, std::int64_t i2) {
232+
return array_less<std::uint8_t, n_qubytes>()(
233+
reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(result_configs.data_ptr())[i1],
234+
reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(result_configs.data_ptr())[i2]
235+
);
236+
}
238237
);
238+
auto sorted_result_configs = result_configs.index({result_sort_index});
239+
auto sorted_result_psi = torch::zeros({result_batch_size, 2}, torch::TensorOptions().dtype(torch::kFloat64).device(device, device_id));
239240

240241
apply_within_kernel_interface<max_op_number, n_qubytes, particle_cut>(
241242
/*term_number=*/term_number,
@@ -255,7 +256,7 @@ auto apply_within_interface(
255256
return result_psi;
256257
}
257258

258-
template<typename T, typename Compare = thrust::less<T>>
259+
template<typename T, typename Compare = std::less<T>>
259260
void add_into_heap(T* heap, std::int64_t heap_size, const T& value) {
260261
auto compare = Compare();
261262
std::int64_t index = 0;
@@ -530,8 +531,7 @@ auto find_relative_interface(
530531

531532
auto sorted_exclude_configs = exclude_configs.clone(torch::MemoryFormat::Contiguous);
532533

533-
thrust::sort(
534-
thrust::host,
534+
std::sort(
535535
reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(sorted_exclude_configs.data_ptr()),
536536
reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(sorted_exclude_configs.data_ptr()) + exclude_size,
537537
array_less<std::uint8_t, n_qubytes>()

0 commit comments

Comments
 (0)