1+ #include " para_linear_transform.h"
2+ #include < vector>
3+ #include < algorithm>
4+ namespace hsolver
5+ {
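// Computes A <- alpha * A * U_global + beta * A for a column-major matrix A of
// nrow rows (leading dimension LDA) whose columns are distributed over the
// nproc_col ranks of col_world; each rank owns ncol_loc of the ncol_glo global
// columns and holds the full ncol_glo x ncol_glo matrix U_global.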
template <typename T, typename Device>
void para_linear_transform_op<T, Device>::operator()(T* A,
                                                     const T alpha,
                                                     const T beta,
                                                     const T* U_global,
                                                     const int& nrow,
                                                     const int& LDA,
                                                     const int& ncol_loc,
                                                     const int& ncol_glo,
#ifdef __MPI
                                                     MPI_Comm col_world,
#endif
                                                     const int rank_col,
                                                     const int nproc_col)
{
    const Device* ctx = {};
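    // Multi-rank case. Writing the column-distributed matrix as
    // A_glo = [A_0 | A_1 | ... | A_{P-1}], this rank (p) must form
    //   A_p <- beta * A_p + alpha * sum_ip A_ip * U(rows of ip, cols of p),
    // so it needs the local block A_ip of every other rank.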
#ifdef __MPI
    if (nproc_col > 1)
    {
        // Gather each rank's local column count and build the exclusive
        // prefix sum, i.e. the global starting column of every rank.
        std::vector<int> colA_loc(nproc_col);
        MPI_Allgather(&ncol_loc, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world);
        std::vector<int> start_col(nproc_col);
        start_col[0] = 0;
        for (int ip = 1; ip < nproc_col; ++ip)
        {
            start_col[ip] = start_col[ip - 1] + colA_loc[ip - 1];
        }
        int max_col = *std::max_element(colA_loc.begin(), colA_loc.end());
        std::vector<MPI_Request> requests(nproc_col);

        // Work buffers: A_tmp(_device) receives remote blocks, A_tmp2 holds a
        // single partial product, and A_sum accumulates the final result.
        std::vector<T> A_tmp(max_col * LDA);
        T* A_tmp_device = A_tmp.data();
        if (std::is_same<Device, base_device::DEVICE_GPU>::value)
        {
            A_tmp_device = nullptr;
            resmem_dev_op()(A_tmp_device, max_col * LDA);
        }
        T* A_tmp2 = nullptr;
        resmem_dev_op()(A_tmp2, ncol_loc * LDA);
        syncmem_dev_op()(A_tmp2, A, ncol_loc * LDA);
        T* A_sum = nullptr;
        resmem_dev_op()(A_sum, ncol_loc * LDA);
        setmem_dev_op()(A_sum, 0.0, ncol_loc * LDA);

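        // Exchange pattern: post a non-blocking send of the local block to
        // every peer up front, then loop over peers, receiving one block at a
        // time, so communication can overlap with the per-block GEMMs below.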
        // Send: ship this rank's local block of A to every other rank.
        for (int ip = 0; ip < nproc_col; ++ip)
        {
            if (rank_col != ip)
            {
                int size = LDA * ncol_loc;
                Parallel_Common::isend_dev<T, Device>(A, size, ip, 0, col_world, &requests[ip], A_tmp.data());
            }
        }

        // Receive: accumulate each rank's contribution to the local columns.
        T* U_local = nullptr;
        resmem_dev_op()(U_local, max_col * ncol_loc);
        const int start = start_col[rank_col];
        for (int ip = 0; ip < nproc_col; ++ip)
        {
            // beta must scale the original A exactly once, on the first pass.
            T real_beta = ip == 0 ? beta : 0;
            const int start_row = start_col[ip];
            const int ncol_ip = colA_loc[ip];
            // Copy the (ncol_ip x ncol_loc) block of U_global that couples
            // rank ip's columns of A to this rank's columns, column by column.
            for (int i = 0; i < ncol_loc; ++i)
            {
                const T* U_glo_tmp = U_global + start_row + (i + start) * ncol_glo;
                syncmem_dev_op()(U_local + i * ncol_ip, U_glo_tmp, ncol_ip);
            }

            if (ip == rank_col)
            {
                // Local block: A_tmp2 = alpha * A * U_local + real_beta * A_tmp2
                ModuleBase::gemm_op<T, Device>()(ctx,
                                                 'N',
                                                 'N',
                                                 nrow,
                                                 ncol_loc,
                                                 ncol_ip,
                                                 &alpha,
                                                 A,
                                                 LDA,
                                                 U_local,
                                                 ncol_ip,
                                                 &real_beta,
                                                 A_tmp2,
                                                 LDA);
            }
            else
            {
                int size = LDA * ncol_ip;
                MPI_Status status;
                Parallel_Common::recv_dev<T, Device>(A_tmp_device, size, ip, 0, col_world, &status, A_tmp.data());
                MPI_Wait(&requests[ip], &status);
                // Remote block: A_tmp2 = alpha * A_ip * U_local + real_beta * A_tmp2
                ModuleBase::gemm_op<T, Device>()(ctx,
                                                 'N',
                                                 'N',
                                                 nrow,
                                                 ncol_loc,
                                                 ncol_ip,
                                                 &alpha,
                                                 A_tmp_device,
                                                 LDA,
                                                 U_local,
                                                 ncol_ip,
                                                 &real_beta,
                                                 A_tmp2,
                                                 LDA);
            }
            // Fold this partial product into the running total.
            T one = 1.0;
            ModuleBase::axpy_op<T, Device>()(ctx, ncol_loc * LDA, &one, A_tmp2, 1, A_sum, 1);
        }
        // Copy the accumulated result back into A and free the work buffers.
        syncmem_dev_op()(A, A_sum, ncol_loc * LDA);
        delmem_dev_op()(U_local);
        delmem_dev_op()(A_tmp2);
        delmem_dev_op()(A_sum);
        if (std::is_same<Device, base_device::DEVICE_GPU>::value)
        {
            delmem_dev_op()(A_tmp_device);
        }
    }
    else
#endif
    {
        // Single-process path: one dense GEMM, A = alpha * A * U_global + beta * A.
        T* A_tmp = nullptr;
        resmem_dev_op()(A_tmp, LDA * ncol_glo);
        syncmem_dev_op()(A_tmp, A, LDA * ncol_loc);
        ModuleBase::gemm_op<T, Device>()(ctx,
                                         'N',
                                         'N',
                                         nrow,
                                         ncol_glo,
                                         ncol_glo,
                                         &alpha,
                                         A_tmp,
                                         LDA,
                                         U_global,
                                         ncol_glo,
                                         &beta,
                                         A,
                                         LDA);
        delmem_dev_op()(A_tmp);
    }
}
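
// Usage sketch (variable names hypothetical; the col_world argument exists
// only when built with __MPI). psi holds this rank's nrow x ncol_loc columns
// with leading dimension lda; U is the full ncol_glo x ncol_glo transform:
//
//   hsolver::para_linear_transform_op<std::complex<double>, base_device::DEVICE_CPU>()(
//       psi, alpha, beta, U, nrow, lda, ncol_loc, ncol_glo, col_world, rank_col, nproc_col);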

template struct para_linear_transform_op<double, base_device::DEVICE_CPU>;
template struct para_linear_transform_op<std::complex<double>, base_device::DEVICE_CPU>;
template struct para_linear_transform_op<std::complex<float>, base_device::DEVICE_CPU>;
#if ((defined __CUDA) || (defined __ROCM))
template struct para_linear_transform_op<double, base_device::DEVICE_GPU>;
template struct para_linear_transform_op<std::complex<double>, base_device::DEVICE_GPU>;
template struct para_linear_transform_op<std::complex<float>, base_device::DEVICE_GPU>;
#endif
} // namespace hsolver