1
- #include < thrust/execution_policy.h>
2
- #include < thrust/sort.h>
3
1
#include < torch/extension.h>
4
2
5
3
namespace qmb_hamiltonian_cpu {
@@ -225,17 +223,20 @@ auto apply_within_interface(
225
223
TORCH_CHECK (coef.size (0 ) == term_number, " coef size must match the provided term_number." );
226
224
TORCH_CHECK (coef.size (1 ) == 2 , " coef must contain 2 elements for each term." );
227
225
228
- auto sorted_result_configs = result_configs.clone (torch::MemoryFormat::Contiguous);
229
226
auto result_sort_index = torch::arange (result_batch_size, torch::TensorOptions ().dtype (torch::kInt64 ).device (device, device_id));
230
- auto sorted_result_psi = torch::zeros ({result_batch_size, 2 }, torch::TensorOptions ().dtype (torch::kFloat64 ).device (device, device_id));
231
227
232
- thrust::sort_by_key (
233
- thrust::host,
234
- reinterpret_cast <std::array<std::uint8_t , n_qubytes>*>(sorted_result_configs.data_ptr ()),
235
- reinterpret_cast <std::array<std::uint8_t , n_qubytes>*>(sorted_result_configs.data_ptr ()) + result_batch_size,
228
+ std::sort (
236
229
reinterpret_cast <std::int64_t *>(result_sort_index.data_ptr ()),
237
- array_less<std::uint8_t , n_qubytes>()
230
+ reinterpret_cast <std::int64_t *>(result_sort_index.data_ptr ()) + result_batch_size,
231
+ [&result_configs](std::int64_t i1, std::int64_t i2) {
232
+ return array_less<std::uint8_t , n_qubytes>()(
233
+ reinterpret_cast <const std::array<std::uint8_t , n_qubytes>*>(result_configs.data_ptr ())[i1],
234
+ reinterpret_cast <const std::array<std::uint8_t , n_qubytes>*>(result_configs.data_ptr ())[i2]
235
+ );
236
+ }
238
237
);
238
+ auto sorted_result_configs = result_configs.index ({result_sort_index});
239
+ auto sorted_result_psi = torch::zeros ({result_batch_size, 2 }, torch::TensorOptions ().dtype (torch::kFloat64 ).device (device, device_id));
239
240
240
241
apply_within_kernel_interface<max_op_number, n_qubytes, particle_cut>(
241
242
/* term_number=*/ term_number,
@@ -255,7 +256,7 @@ auto apply_within_interface(
255
256
return result_psi;
256
257
}
257
258
258
- template <typename T, typename Compare = thrust ::less<T>>
259
+ template <typename T, typename Compare = std ::less<T>>
259
260
void add_into_heap (T* heap, std::int64_t heap_size, const T& value) {
260
261
auto compare = Compare ();
261
262
std::int64_t index = 0 ;
@@ -530,8 +531,7 @@ auto find_relative_interface(
530
531
531
532
auto sorted_exclude_configs = exclude_configs.clone (torch::MemoryFormat::Contiguous);
532
533
533
- thrust::sort (
534
- thrust::host,
534
+ std::sort (
535
535
reinterpret_cast <std::array<std::uint8_t , n_qubytes>*>(sorted_exclude_configs.data_ptr ()),
536
536
reinterpret_cast <std::array<std::uint8_t , n_qubytes>*>(sorted_exclude_configs.data_ptr ()) + exclude_size,
537
537
array_less<std::uint8_t , n_qubytes>()
0 commit comments