33#ifdef __DSP
44
55#include " module_base/module_device/device.h"
6+ #include " module_base/module_device/memory_op.h"
67#include " module_hsolver/diag_comm_info.h"
78
89// Base dsp functions
@@ -67,18 +68,17 @@ void cgemm_mth_(const char *transa, const char *transb,
6768
6869// The next is dsp utils. It may be moved to other files if this file get too huge
6970
70- Device* ctx = {};
7171base_device::DEVICE_CPU* cpu_ctx = {};
7272base_device::AbacusDevice_t device = {};
7373
7474template <typename T>
75- void dsp_dav_subspace_reduce (T* hcc, T* scc, int nbase_x, int notconv, const diag_comm_info diag_comm){
75+ void dsp_dav_subspace_reduce (T* hcc, T* scc, int nbase_x, int notconv, MPI_Comm diag_comm){
7676
7777 using syncmem_complex_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
7878
7979 auto * swap = new T[notconv * nbase_x];
8080 auto * target = new T[notconv * nbase_x];
81- syncmem_complex_op ()(ctx, ctx , swap, hcc + nbase * nbase_x, notconv * nbase_x);
81+ syncmem_complex_op ()(cpu_ctx, cpu_ctx , swap, hcc + nbase * nbase_x, notconv * nbase_x);
8282 if (base_device::get_current_precision (swap) == " single" )
8383 {
8484 MPI_Reduce (swap,
@@ -87,7 +87,7 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase_x, int notconv, const dia
8787 MPI_COMPLEX,
8888 MPI_SUM,
8989 0 ,
90- diag_comm. comm );
90+ diag_comm);
9191 }
9292 else
9393 {
@@ -97,11 +97,11 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase_x, int notconv, const dia
9797 MPI_DOUBLE_COMPLEX,
9898 MPI_SUM,
9999 0 ,
100- diag_comm. comm );
100+ diag_comm);
101101 }
102102
103- syncmem_complex_op ()(ctx, ctx , hcc + nbase * nbase_x, target, notconv * nbase_x);
104- syncmem_complex_op ()(ctx, ctx , swap, scc + nbase * nbase_x, notconv * nbase_x);
103+ syncmem_complex_op ()(cpu_ctx, cpu_ctx , hcc + nbase * nbase_x, target, notconv * nbase_x);
104+ syncmem_complex_op ()(cpu_ctx, cpu_ctx , swap, scc + nbase * nbase_x, notconv * nbase_x);
105105
106106 if (base_device::get_current_precision (swap) == " single" )
107107 {
@@ -111,7 +111,7 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase_x, int notconv, const dia
111111 MPI_COMPLEX,
112112 MPI_SUM,
113113 0 ,
114- diag_comm. comm );
114+ diag_comm);
115115 }
116116 else
117117 {
@@ -121,10 +121,10 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase_x, int notconv, const dia
121121 MPI_DOUBLE_COMPLEX,
122122 MPI_SUM,
123123 0 ,
124- diag_comm. comm );
124+ diag_comm);
125125 }
126126
127- syncmem_complex_op ()(ctx, ctx , scc + nbase * nbase_x, target, notconv * nbase_x);
127+ syncmem_complex_op ()(cpu_ctx, cpu_ctx , scc + nbase * nbase_x, target, notconv * nbase_x);
128128 delete[] swap;
129129 delete[] target;
130130}
0 commit comments