1
1
#include < ATen/cuda/Exceptions.h>
2
2
#include < c10/cuda/CUDAStream.h>
3
3
#include < cuda_runtime.h>
4
+ #include < curand_kernel.h>
4
5
#include < thrust/sort.h>
5
6
#include < torch/extension.h>
6
7
@@ -670,6 +671,198 @@ auto find_relative_interface(
670
671
return unique_nonzero_result_config;
671
672
}
672
673
674
// Per-(term, configuration) worker used to sample ONE relative configuration for each input
// configuration, with probability proportional to the magnitude of the generating term's coefficient.
//
// Steps:
//   1. Copy configs[batch_index] and run hamiltonian_apply_kernel for term `term_index` on the copy;
//      stop if it reports failure.
//      NOTE(review): the copy is presumably modified in place by hamiltonian_apply_kernel - confirm.
//   2. Binary-search the candidate in exclude_configs and stop if it is present. exclude_configs must
//      therefore be sorted consistently with array_less<std::uint8_t, n_qubytes>.
//   3. Weighted reservoir sampling (Efraimidis-Spirakis): draw u uniformly, form key = u^(1/weight)
//      with weight = |coef[term_index]|, and keep the candidate with the largest key for this
//      batch_index. score[batch_index] holds the best key so far, result_configs[batch_index] the
//      matching configuration, and mutex[batch_index] serialises the update.
//
// Parameters:
//   term_index, batch_index  indices of the term and the configuration handled by this call.
//   term_number, batch_size  extents of the term and batch dimensions.
//   exclude_size             number of entries in exclude_configs.
//   seed                     cuRAND seed shared by the whole launch.
//   site, kind, coef         per-term data, indexed by term_index.
//   configs                  input configurations, indexed by batch_index.
//   exclude_configs          sorted configurations that must not be produced.
//   result_configs, score    per-configuration outputs, updated under mutex[batch_index].
//   mutex                    one lock word per configuration.
//                            NOTE(review): callers zero it before launch; presumably 0 means unlocked.
template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
__device__ void single_relative_kernel(
    std::int64_t term_index,
    std::int64_t batch_index,
    std::int64_t term_number,
    std::int64_t batch_size,
    std::int64_t exclude_size,
    std::uint64_t seed,
    const std::array<std::int16_t, max_op_number>* site, // term_number
    const std::array<std::uint8_t, max_op_number>* kind, // term_number
    const std::array<double, 2>* coef, // term_number
    const std::array<std::uint8_t, n_qubytes>* configs, // batch_size
    const std::array<std::uint8_t, n_qubytes>* exclude_configs, // exclude_size
    std::array<std::uint8_t, n_qubytes>* result_configs, // batch_size
    double* score, // batch_size
    int* mutex // batch_size
) {
    std::array<std::uint8_t, n_qubytes> current_configs = configs[batch_index];
    auto [success, parity] = hamiltonian_apply_kernel<max_op_number, n_qubytes, particle_cut>(
        /*current_configs=*/current_configs,
        /*term_index=*/term_index,
        /*batch_index=*/batch_index,
        /*site=*/site,
        /*kind=*/kind
    );

    if (!success) {
        return;
    }

    // Reject the candidate if it is one of the excluded configurations (binary search).
    auto compare = array_less<std::uint8_t, n_qubytes>();
    std::int64_t low = 0;
    std::int64_t high = exclude_size - 1;
    while (low <= high) {
        // low + (high - low) / 2 equals (low + high) / 2 here but cannot overflow.
        const std::int64_t mid = low + (high - low) / 2;
        if (compare(current_configs, exclude_configs[mid])) {
            high = mid - 1;
        } else if (compare(exclude_configs[mid], current_configs)) {
            low = mid + 1;
        } else {
            // Neither is less than the other: the candidate is excluded.
            return;
        }
    }

    // Efraimidis-Spirakis Algorithm is used here.
    const double real_part = coef[term_index][0];
    const double imag_part = coef[term_index][1];
    const double weight = std::sqrt(real_part * real_part + imag_part * imag_part);

    // Every (batch_index, term_index) pair gets its own cuRAND subsequence. Using term_index alone
    // would hand the same random number to all configurations in the batch for a given term, which
    // correlates the sampled relatives across the batch.
    curandState state;
    curand_init(seed, static_cast<unsigned long long>(batch_index * term_number + term_index), 0, &state);
    const double key = std::pow(curand_uniform_double(&state), 1.0 / weight);

    // Cheap unlocked pre-check first; the decision is re-validated once the lock is held.
    if (score[batch_index] < key) {
        mutex_lock(&mutex[batch_index]);
        if (score[batch_index] < key) {
            score[batch_index] = key;
            result_configs[batch_index] = current_configs;
        }
        mutex_unlock(&mutex[batch_index]);
    }
}
737
+
738
// Kernel entry point for single_relative_kernel.
//
// Launch layout: a 2D grid in which the x dimension enumerates Hamiltonian terms and the y dimension
// enumerates input configurations. Threads whose indices fall outside
// [0, term_number) x [0, batch_size) return without doing any work, so the grid may be rounded up.
// All pointer arguments are forwarded unchanged to single_relative_kernel.
template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
__global__ void single_relative_kernel_interface(
    std::int64_t term_number,
    std::int64_t batch_size,
    std::int64_t exclude_size,
    std::uint64_t seed,
    const std::array<std::int16_t, max_op_number>* site, // term_number
    const std::array<std::uint8_t, max_op_number>* kind, // term_number
    const std::array<double, 2>* coef, // term_number
    const std::array<std::uint8_t, n_qubytes>* configs, // batch_size
    const std::array<std::uint8_t, n_qubytes>* exclude_configs, // exclude_size
    std::array<std::uint8_t, n_qubytes>* result_configs, // batch_size
    double* score, // batch_size
    int* mutex // batch_size
) {
    // Widen to 64 bits BEFORE multiplying: blockIdx * blockDim is otherwise evaluated in 32-bit
    // unsigned arithmetic and can wrap around for very large grids.
    const std::int64_t term_index = static_cast<std::int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const std::int64_t batch_index = static_cast<std::int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;

    // Threads in the rounded-up tail of the grid have nothing to do.
    if (term_index >= term_number || batch_index >= batch_size) {
        return;
    }

    single_relative_kernel<max_op_number, n_qubytes, particle_cut>(
        /*term_index=*/term_index,
        /*batch_index=*/batch_index,
        /*term_number=*/term_number,
        /*batch_size=*/batch_size,
        /*exclude_size=*/exclude_size,
        /*seed=*/seed,
        /*site=*/site,
        /*kind=*/kind,
        /*coef=*/coef,
        /*configs=*/configs,
        /*exclude_configs=*/exclude_configs,
        /*result_configs=*/result_configs,
        /*score=*/score,
        /*mutex=*/mutex
    );
}
775
+
776
// Host entry point: for every input configuration, sample one configuration reachable through a
// single Hamiltonian term that is not itself part of the input batch.
//
// Arguments (all CUDA tensors on the same device, contiguous):
//   configs  uint8   [batch_size, n_qubytes]       input configurations.
//   site     int16   [term_number, max_op_number]  per-term data forwarded to the kernel.
//   kind     uint8   [term_number, max_op_number]  per-term data forwarded to the kernel.
//   coef     float64 [term_number, 2]              two components per term; the kernel uses their
//                                                  Euclidean norm as the sampling weight.
//
// Returns a uint8 tensor [batch_size, n_qubytes]. Rows for which no term produced an admissible
// candidate keep their initial all-zero content.
//
// The seed is drawn from torch's default CPU generator, so results follow torch's seeding.
// The function synchronises the current CUDA stream before returning.
template<std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
auto single_relative_interface(const torch::Tensor& configs, const torch::Tensor& site, const torch::Tensor& kind, const torch::Tensor& coef)
    -> torch::Tensor {
    std::int64_t device_id = configs.device().index();
    std::int64_t batch_size = configs.size(0);
    std::int64_t term_number = site.size(0);

    TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.");
    TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
    TORCH_CHECK(configs.is_contiguous(), "configs must be contiguous.");
    TORCH_CHECK(configs.dtype() == torch::kUInt8, "configs must be uint8.");
    TORCH_CHECK(configs.dim() == 2, "configs must be 2D.");
    TORCH_CHECK(configs.size(0) == batch_size, "configs batch size must match the provided batch_size.");
    TORCH_CHECK(configs.size(1) == n_qubytes, "configs must have the same number of qubits as the provided n_qubytes.");

    TORCH_CHECK(site.device().type() == torch::kCUDA, "site must be on CUDA.");
    TORCH_CHECK(site.device().index() == device_id, "site must be on the same device as others.");
    TORCH_CHECK(site.is_contiguous(), "site must be contiguous.");
    TORCH_CHECK(site.dtype() == torch::kInt16, "site must be int16.");
    TORCH_CHECK(site.dim() == 2, "site must be 2D.");
    TORCH_CHECK(site.size(0) == term_number, "site size must match the provided term_number.");
    TORCH_CHECK(site.size(1) == max_op_number, "site must match the provided max_op_number.");

    TORCH_CHECK(kind.device().type() == torch::kCUDA, "kind must be on CUDA.");
    TORCH_CHECK(kind.device().index() == device_id, "kind must be on the same device as others.");
    TORCH_CHECK(kind.is_contiguous(), "kind must be contiguous.");
    TORCH_CHECK(kind.dtype() == torch::kUInt8, "kind must be uint8.");
    TORCH_CHECK(kind.dim() == 2, "kind must be 2D.");
    TORCH_CHECK(kind.size(0) == term_number, "kind size must match the provided term_number.");
    TORCH_CHECK(kind.size(1) == max_op_number, "kind must match the provided max_op_number.");

    TORCH_CHECK(coef.device().type() == torch::kCUDA, "coef must be on CUDA.");
    TORCH_CHECK(coef.device().index() == device_id, "coef must be on the same device as others.");
    TORCH_CHECK(coef.is_contiguous(), "coef must be contiguous.");
    TORCH_CHECK(coef.dtype() == torch::kFloat64, "coef must be float64.");
    TORCH_CHECK(coef.dim() == 2, "coef must be 2D.");
    TORCH_CHECK(coef.size(0) == term_number, "coef size must match the provided term_number.");
    TORCH_CHECK(coef.size(1) == 2, "coef must contain 2 elements for each term.");

    auto result_configs = torch::zeros({batch_size, n_qubytes}, torch::TensorOptions().dtype(torch::kUInt8).device(device, device_id));

    // Nothing to sample: a grid with a zero-sized dimension is an invalid launch configuration.
    if (batch_size == 0 || term_number == 0) {
        return result_configs;
    }

    auto stream = at::cuda::getCurrentCUDAStream(device_id);
    auto policy = thrust::device.on(stream);

    cudaDeviceProp prop;
    AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));
    std::int64_t max_threads_per_block = prop.maxThreadsPerBlock;

    // The kernel binary-searches the exclusion list, so it needs a sorted copy of the input batch.
    auto sorted_configs = configs.clone(torch::MemoryFormat::Contiguous);
    auto* sorted_begin = reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(sorted_configs.data_ptr());
    thrust::sort(policy, sorted_begin, sorted_begin + batch_size, array_less<std::uint8_t, n_qubytes>());

    // Draw the seed on the CPU from [0, INT64_MAX). Note the call parentheses on max(): without them
    // the address of the function, not the limit, would be used as the upper bound.
    auto seed_tensor =
        torch::randint(std::int64_t(0), std::numeric_limits<std::int64_t>::max(), {}, torch::TensorOptions().dtype(torch::kInt64));
    const auto seed = static_cast<std::uint64_t>(seed_tensor.item<std::int64_t>());

    auto score = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat64).device(device, device_id))
                     .fill_(-std::numeric_limits<double>::infinity());
    // One lock word per configuration, zero-initialised. Owning it through a tensor (instead of
    // cudaMalloc/cudaFree) releases it even when a check below throws, and orders the zero-fill on
    // the same stream as the kernel.
    auto mutex = torch::zeros({batch_size}, torch::TensorOptions().dtype(torch::kInt32).device(device, device_id));

    // Half of the device-wide block-size limit: launching with the full limit failed for the original
    // author. NOTE(review): presumably the kernel's register usage lowers its own per-kernel limit;
    // querying cudaFuncGetAttributes would make this exact - confirm.
    const dim3 threads_per_block(1, static_cast<unsigned int>(max_threads_per_block >> 1));
    const dim3 num_blocks(
        static_cast<unsigned int>((term_number + threads_per_block.x - 1) / threads_per_block.x),
        static_cast<unsigned int>((batch_size + threads_per_block.y - 1) / threads_per_block.y)
    );

    single_relative_kernel_interface<max_op_number, n_qubytes, particle_cut><<<num_blocks, threads_per_block, 0, stream>>>(
        /*term_number=*/term_number,
        /*batch_size=*/batch_size,
        /*exclude_size=*/batch_size,
        /*seed=*/seed,
        /*site=*/reinterpret_cast<const std::array<std::int16_t, max_op_number>*>(site.data_ptr()),
        /*kind=*/reinterpret_cast<const std::array<std::uint8_t, max_op_number>*>(kind.data_ptr()),
        /*coef=*/reinterpret_cast<const std::array<double, 2>*>(coef.data_ptr()),
        /*configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(configs.data_ptr()),
        /*exclude_configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(sorted_configs.data_ptr()),
        /*result_configs=*/reinterpret_cast<std::array<std::uint8_t, n_qubytes>*>(result_configs.data_ptr()),
        /*score=*/reinterpret_cast<double*>(score.data_ptr()),
        /*mutex=*/mutex.data_ptr<int>()
    );
    // Launch-configuration errors are only reported through cudaGetLastError.
    AT_CUDA_CHECK(cudaGetLastError());
    // Surfaces asynchronous execution errors before the result is handed back.
    AT_CUDA_CHECK(cudaStreamSynchronize(stream));

    return result_configs;
}
865
+
673
866
#ifndef N_QUBYTES
674
867
#define N_QUBYTES 0
675
868
#endif
@@ -683,6 +876,7 @@ auto find_relative_interface(
683
876
// Registers the CUDA implementations of the library's operators for the (N_QUBYTES, PARTICLE_CUT)
// instantiation selected at compile time. Each name must match the operator name declared for the
// library exactly (no surrounding whitespace). All entries are instantiated with max_op_number = 4.
TORCH_LIBRARY_IMPL(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), CUDA, m) {
    m.impl("apply_within", apply_within_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
    m.impl("find_relative", find_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
    m.impl("single_relative", single_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
}
687
881
#undef QMB_LIBRARY
688
882
#undef QMB_LIBRARY_HELPER
0 commit comments