@@ -32,6 +32,8 @@ class TestModuleHamiltVeff : public ::testing::Test
3232 const base_device::DEVICE_CPU* cpu_ctx = {};
3333 const base_device::DEVICE_GPU* gpu_ctx = {};
3434
35+ using rearrange_cpu = hamilt::rearrange<double , base_device::DEVICE_CPU>;
36+ using rearrange_gpu = hamilt::rearrange<double , base_device::DEVICE_GPU>;
3537 using veff_cpu_op = hamilt::veff_pw_op<double , base_device::DEVICE_CPU>;
3638 using veff_gpu_op = hamilt::veff_pw_op<double , base_device::DEVICE_GPU>;
3739
@@ -70,13 +72,10 @@ TEST_F(TestModuleHamiltVeff, veff_pw_spin_op_cpu)
7072 std::vector<std::complex <double >> expected_out1_spin (out_spin.size (), std::complex <double >(0 , 0 ));
7173 std::vector<std::complex <double >> res = out_spin;
7274 std::vector<std::complex <double >> res1 = out1_spin;
75+ std::vector<std::complex <double >> in_ (4 * in_spin.size (), std::complex <double >(0 , 0 ));
76+ rearrange_cpu ()(cpu_ctx, this ->size , in_spin.data (), in_.data ());
7377
74- const double * in_[4 ];
75- for (int ii = 0 ; ii < 4 ; ii++) {
76- in_[ii] = in_spin.data () + ii * this ->size ;
77- }
78-
79- veff_cpu_op ()(cpu_ctx, this ->size , res.data (), res1.data (), in_);
78+ veff_cpu_op ()(cpu_ctx, this ->size , res.data (), res1.data (), in_.data ());
8079 for (int ii = 0 ; ii < res.size (); ii++) {
8180 EXPECT_LT (std::abs (res[ii] - expected_out_spin[ii]), 7.5e-5 );
8281 EXPECT_LT (std::abs (res1[ii] - expected_out1_spin[ii]), 6e-5 );
@@ -108,23 +107,22 @@ TEST_F(TestModuleHamiltVeff, veff_pw_spin_op_gpu)
108107{
109108 std::vector<std::complex <double >> out1_spin (out_spin.size (), std::complex <double >(0 , 0 ));
110109 std::vector<std::complex <double >> expected_out1_spin (out_spin.size (), std::complex <double >(0 , 0 ));
110+ std::vector<std::complex <double >> in_ (4 * in_spin.size (), std::complex <double >(0 , 0 ));
111111 std::vector<std::complex <double >> res = out_spin;
112112 std::vector<std::complex <double >> res1 = out1_spin;
113113 double * d_in = NULL ;
114+ std::complex <double >* d_in_ = NULL ;
114115 std::complex <double >* d_res = NULL , * d_res1 = NULL ;
115116 resize_memory_double_op ()(d_in, in_spin.size ());
116117 resize_memory_complex_op ()(d_res, res.size ());
117118 resize_memory_complex_op ()(d_res1, res1.size ());
119+ resize_memory_complex_op ()(d_in_, in_spin.size ()*4 );
118120 syncmem_double_h2d_op ()(d_in, in_spin.data (), in_spin.size ());
119121 syncmem_complex_h2d_op ()(d_res, res.data (), res.size ());
120122 syncmem_complex_h2d_op ()(d_res1, res1.data (), res1.size ());
121-
122- const double * in_[4 ];
123- for (int ii = 0 ; ii < 4 ; ii++) {
124- in_[ii] = d_in + ii * this ->size ;
125- }
126-
127- veff_gpu_op ()(gpu_ctx, this ->size , d_res, d_res1, in_);
123+
124+ rearrange_gpu ()(gpu_ctx, this ->size , d_in, d_in_);
125+ veff_gpu_op ()(gpu_ctx, this ->size , d_res, d_res1, d_in_);
128126
129127 syncmem_complex_d2h_op ()(res.data (), d_res, res.size ());
130128 syncmem_complex_d2h_op ()(res1.data (), d_res1, res1.size ());
@@ -135,5 +133,6 @@ TEST_F(TestModuleHamiltVeff, veff_pw_spin_op_gpu)
135133 delete_memory_double_op ()(d_in);
136134 delete_memory_complex_op ()(d_res);
137135 delete_memory_complex_op ()(d_res1);
136+ delete_memory_complex_op ()(d_in_);
138137}
139138#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
0 commit comments