@@ -674,6 +674,134 @@ auto find_relative_interface(
    return unique_nonzero_result_config;
}

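+ // Computes the contribution of one Hamiltonian term to the diagonal element
+ // <config|H|config> of a single batch configuration. The term is applied via
+ // hamiltonian_apply_kernel; only when the resulting configuration equals the original
+ // one is the term's coefficient (two doubles per term, presumably real and imaginary
+ // parts) accumulated into result_psi, signed by the parity returned from the apply.
+ // atomicAdd is needed because all terms acting on the same configuration update the
+ // same result_psi entry concurrently.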
+ template <std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+ __device__ void diagonal_term_kernel(
+     std::int64_t term_index,
+     std::int64_t batch_index,
+     std::int64_t term_number,
+     std::int64_t batch_size,
+     const std::array<std::int16_t, max_op_number>* site,   // term_number
+     const std::array<std::uint8_t, max_op_number>* kind,   // term_number
+     const std::array<double, 2>* coef,                     // term_number
+     const std::array<std::uint8_t, n_qubytes>* configs,    // batch_size
+     std::array<double, 2>* result_psi
+ ) {
+     std::array<std::uint8_t, n_qubytes> current_configs = configs[batch_index];
+     auto [success, parity] = hamiltonian_apply_kernel<max_op_number, n_qubytes, particle_cut>(
+         /*current_configs=*/current_configs,
+         /*term_index=*/term_index,
+         /*batch_index=*/batch_index,
+         /*site=*/site,
+         /*kind=*/kind
+     );
+
+     if (!success) {
+         return;
+     }
+     auto less = array_less<std::uint8_t, n_qubytes>();
+     if (less(current_configs, configs[batch_index]) || less(configs[batch_index], current_configs)) {
+         return; // The term maps this configuration to a different one, so it has no diagonal contribution
+     }
+     std::int8_t sign = parity ? -1 : +1;
+     atomicAdd(&result_psi[batch_index][0], sign * coef[term_index][0]);
+     atomicAdd(&result_psi[batch_index][1], sign * coef[term_index][1]);
+ }
+
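+ // Thin __global__ wrapper over diagonal_term_kernel: one thread per (term, configuration)
+ // pair, with the x index running over terms and the y index over batch entries;
+ // out-of-range threads return immediately.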
+ template <std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+ __global__ void diagonal_term_kernel_interface(
+     std::int64_t term_number,
+     std::int64_t batch_size,
+     const std::array<std::int16_t, max_op_number>* site,   // term_number
+     const std::array<std::uint8_t, max_op_number>* kind,   // term_number
+     const std::array<double, 2>* coef,                     // term_number
+     const std::array<std::uint8_t, n_qubytes>* configs,    // batch_size
+     std::array<double, 2>* result_psi
+ ) {
+     std::int64_t term_index = blockIdx.x * blockDim.x + threadIdx.x;
+     std::int64_t batch_index = blockIdx.y * blockDim.y + threadIdx.y;
+
+     if (term_index < term_number && batch_index < batch_size) {
+         diagonal_term_kernel<max_op_number, n_qubytes, particle_cut>(
+             /*term_index=*/term_index,
+             /*batch_index=*/batch_index,
+             /*term_number=*/term_number,
+             /*batch_size=*/batch_size,
+             /*site=*/site,
+             /*kind=*/kind,
+             /*coef=*/coef,
+             /*configs=*/configs,
+             /*result_psi=*/result_psi
+         );
+     }
+ }
+
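+ // Host-side entry point, registered below as the "diagonal_term" op. It validates that
+ // all inputs are contiguous CUDA tensors of the expected dtype and shape, allocates a
+ // zero-initialized result_psi of shape (batch_size, 2) in float64, launches the kernel
+ // on the current stream, and synchronizes before returning.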
+ template <std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
+ auto diagonal_term_interface(const torch::Tensor& configs, const torch::Tensor& site, const torch::Tensor& kind, const torch::Tensor& coef)
+     -> torch::Tensor {
+     std::int64_t device_id = configs.device().index();
+     std::int64_t batch_size = configs.size(0);
+     std::int64_t term_number = site.size(0);
+     at::cuda::CUDAGuard cuda_device_guard(device_id);
+
+     TORCH_CHECK(configs.device().type() == torch::kCUDA, "configs must be on CUDA.");
+     TORCH_CHECK(configs.device().index() == device_id, "configs must be on the same device as others.");
+     TORCH_CHECK(configs.is_contiguous(), "configs must be contiguous.");
+     TORCH_CHECK(configs.dtype() == torch::kUInt8, "configs must be uint8.");
+     TORCH_CHECK(configs.dim() == 2, "configs must be 2D.");
+     TORCH_CHECK(configs.size(0) == batch_size, "configs batch size must match the provided batch_size.");
+     TORCH_CHECK(configs.size(1) == n_qubytes, "configs must have the same number of qubytes as the provided n_qubytes.");
+
+     TORCH_CHECK(site.device().type() == torch::kCUDA, "site must be on CUDA.");
+     TORCH_CHECK(site.device().index() == device_id, "site must be on the same device as others.");
+     TORCH_CHECK(site.is_contiguous(), "site must be contiguous.");
+     TORCH_CHECK(site.dtype() == torch::kInt16, "site must be int16.");
+     TORCH_CHECK(site.dim() == 2, "site must be 2D.");
+     TORCH_CHECK(site.size(0) == term_number, "site size must match the provided term_number.");
+     TORCH_CHECK(site.size(1) == max_op_number, "site must match the provided max_op_number.");
+
+     TORCH_CHECK(kind.device().type() == torch::kCUDA, "kind must be on CUDA.");
+     TORCH_CHECK(kind.device().index() == device_id, "kind must be on the same device as others.");
+     TORCH_CHECK(kind.is_contiguous(), "kind must be contiguous.");
+     TORCH_CHECK(kind.dtype() == torch::kUInt8, "kind must be uint8.");
+     TORCH_CHECK(kind.dim() == 2, "kind must be 2D.");
+     TORCH_CHECK(kind.size(0) == term_number, "kind size must match the provided term_number.");
+     TORCH_CHECK(kind.size(1) == max_op_number, "kind must match the provided max_op_number.");
+
+     TORCH_CHECK(coef.device().type() == torch::kCUDA, "coef must be on CUDA.");
+     TORCH_CHECK(coef.device().index() == device_id, "coef must be on the same device as others.");
+     TORCH_CHECK(coef.is_contiguous(), "coef must be contiguous.");
+     TORCH_CHECK(coef.dtype() == torch::kFloat64, "coef must be float64.");
+     TORCH_CHECK(coef.dim() == 2, "coef must be 2D.");
+     TORCH_CHECK(coef.size(0) == term_number, "coef size must match the provided term_number.");
+     TORCH_CHECK(coef.size(1) == 2, "coef must contain 2 elements for each term.");
+
+     auto stream = at::cuda::getCurrentCUDAStream(device_id);
+     auto policy = thrust::device.on(stream);
+
+     cudaDeviceProp prop;
+     AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));
+     std::int64_t max_threads_per_block = prop.maxThreadsPerBlock;
+
+     auto result_psi = torch::zeros({batch_size, 2}, torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCUDA, device_id));
+
+     // Launching with the full maxThreadsPerBlock can fail, likely due to a per-block resource limit, so halve it.
+     auto threads_per_block = dim3{1, max_threads_per_block >> 1};
+     auto num_blocks =
+         dim3{(term_number + threads_per_block.x - 1) / threads_per_block.x, (batch_size + threads_per_block.y - 1) / threads_per_block.y};
+
+     diagonal_term_kernel_interface<max_op_number, n_qubytes, particle_cut><<<num_blocks, threads_per_block, 0, stream>>>(
+         /*term_number=*/term_number,
+         /*batch_size=*/batch_size,
+         /*site=*/reinterpret_cast<const std::array<std::int16_t, max_op_number>*>(site.data_ptr()),
+         /*kind=*/reinterpret_cast<const std::array<std::uint8_t, max_op_number>*>(kind.data_ptr()),
+         /*coef=*/reinterpret_cast<const std::array<double, 2>*>(coef.data_ptr()),
+         /*configs=*/reinterpret_cast<const std::array<std::uint8_t, n_qubytes>*>(configs.data_ptr()),
+         /*result_psi=*/reinterpret_cast<std::array<double, 2>*>(result_psi.data_ptr())
+     );
+     AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+
+     return result_psi;
+ }
+
template <std::int64_t max_op_number, std::int64_t n_qubytes, std::int64_t particle_cut>
__device__ void single_relative_kernel(
    std::int64_t term_index,
@@ -880,6 +1008,7 @@ auto single_relative_interface(const torch::Tensor& configs, const torch::Tensor
TORCH_LIBRARY_IMPL(QMB_LIBRARY(N_QUBYTES, PARTICLE_CUT), CUDA, m) {
    m.impl("apply_within", apply_within_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
    m.impl("find_relative", find_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
+     m.impl("diagonal_term", diagonal_term_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
    m.impl("single_relative", single_relative_interface</*max_op_number=*/4, /*n_qubytes=*/N_QUBYTES, /*particle_cut=*/PARTICLE_CUT>);
}
#undef QMB_LIBRARY