Skip to content

Commit 02f2e1f

Browse files
committed
1. optimize CS_Matrix_Tools::cal_uplimit()
2. update Blas_Interface::matcopy()
1 parent 07afd82 commit 02f2e1f

File tree

8 files changed

+297
-15
lines changed

8 files changed

+297
-15
lines changed

include/RI/global/Blas_Interface-Contiguous.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,58 @@ namespace Blas_Interface
100100
}
101101
}
102102

103+
104+
105+
#ifdef __MKL_RI
106+
107+
namespace Blas_Interface
108+
{
109+
inline size_t get_lda_matcopy(const char ordering, size_t rows, size_t cols)
110+
{
111+
switch(std::toupper(ordering))
112+
{
113+
case 'R': return cols;
114+
case 'C': return rows;
115+
default: throw std::invalid_argument("ordering cannot be "+std::to_string(ordering)+". "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
116+
}
117+
}
118+
inline size_t get_ldb_matcopy(const char ordering, const char trans, size_t rows, size_t cols)
119+
{
120+
switch(std::toupper(ordering))
121+
{
122+
case 'R':
123+
switch(std::toupper(trans))
124+
{
125+
case 'N': case 'R': return cols;
126+
case 'T': case 'C': return rows;
127+
default: throw std::invalid_argument("trans cannot be "+std::to_string(trans)+". "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
128+
}
129+
case 'C':
130+
switch(std::toupper(trans))
131+
{
132+
case 'N': case 'R': return rows;
133+
case 'T': case 'C': return cols;
134+
default: throw std::invalid_argument("trans cannot be "+std::to_string(trans)+". "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
135+
}
136+
default: throw std::invalid_argument("ordering cannot be "+std::to_string(ordering)+". "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
137+
}
138+
}
139+
template<typename T>
140+
inline void imatcopy (const char ordering, const char trans, size_t rows, size_t cols, const T alpha, T * AB)
141+
{
142+
const size_t lda = get_lda_matcopy(ordering, rows, cols);
143+
const size_t ldb = get_ldb_matcopy(ordering, trans, rows, cols);
144+
imatcopy (ordering, trans, rows, cols, alpha, AB, lda, ldb);
145+
}
146+
template<typename T>
147+
inline void omatcopy (char ordering, char trans, size_t rows, size_t cols, const T alpha, const T * A, T * B)
148+
{
149+
const size_t lda = get_lda_matcopy(ordering, rows, cols);
150+
const size_t ldb = get_ldb_matcopy(ordering, trans, rows, cols);
151+
omatcopy (ordering, trans, rows, cols, alpha, A, lda, B, ldb);
152+
}
153+
}
154+
155+
#endif
156+
103157
}

include/RI/global/Blas_Interface-Tensor.h

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ namespace RI
1616
namespace Blas_Interface
1717
{
1818
// nrm2 = ||x||_2
19-
template<typename T>
20-
inline Global_Func::To_Real_t<T> nrm2(const Tensor<T> &X)
19+
template<typename T, template<typename> class Tvec>
20+
inline Global_Func::To_Real_t<T> nrm2(const Tvec<T> &X)
2121
{
2222
return nrm2(X.get_shape_all(), X.ptr());
2323
}
@@ -187,6 +187,41 @@ namespace Blas_Interface
187187
}
188188
}
189189

190+
191+
#ifdef __MKL_RI
192+
193+
namespace Blas_Interface
194+
{
195+
template<typename T>
196+
inline void imatcopy (const char trans, const T alpha, Tensor<T> &AB)
197+
{
198+
assert(AB.shape.size()==2);
199+
imatcopy ('R', trans, AB.shape[0], AB.shape[1], alpha, AB.ptr());
200+
switch(std::toupper(trans))
201+
{
202+
case 'N': case 'R': break;
203+
case 'T': case 'C': AB=AB.reshape({AB.shape[1], AB.shape[0]}); break;
204+
default: throw std::invalid_argument("trans cannot be "+std::to_string(trans)+". "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
205+
}
206+
}
207+
template<typename T>
208+
inline Tensor<T> omatcopy (char trans, const T alpha, const Tensor<T> &A)
209+
{
210+
assert(A.shape.size()==2);
211+
Tensor<T> B;
212+
switch(std::toupper(trans))
213+
{
214+
case 'N': case 'R': B = Tensor<T>({A.shape[0], A.shape[1]}); break;
215+
case 'T': case 'C': B = Tensor<T>({A.shape[1], A.shape[0]}); break;
216+
default: throw std::invalid_argument("trans cannot be "+std::to_string(trans)+". "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
217+
}
218+
omatcopy ('R', trans, A.shape[0], A.shape[1], alpha, A.ptr(), B.ptr());
219+
return B;
220+
}
221+
}
222+
223+
#endif
224+
190225
}
191226

192227
#include "Tensor.hpp"

include/RI/global/Blas_Interface.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
#include <string>
1111
#include <stdexcept>
1212

13+
14+
#ifdef __MKL_RI
15+
#include <mkl_trans.h>
16+
#endif
17+
1318
namespace RI
1419
{
1520

@@ -211,4 +216,47 @@ namespace Blas_Interface
211216
}
212217
}
213218

219+
220+
221+
#ifdef __MKL_RI
222+
223+
namespace Blas_Interface
224+
{
225+
inline void imatcopy (const char ordering, const char trans, size_t rows, size_t cols, const float alpha, float * AB, size_t lda, size_t ldb)
226+
{
227+
mkl_simatcopy (ordering, trans, rows, cols, alpha, AB, lda, ldb);
228+
}
229+
inline void imatcopy (const char ordering, const char trans, size_t rows, size_t cols, const double alpha, double * AB, size_t lda, size_t ldb)
230+
{
231+
mkl_dimatcopy (ordering, trans, rows, cols, alpha, AB, lda, ldb);
232+
}
233+
inline void imatcopy (const char ordering, const char trans, size_t rows, size_t cols, const std::complex<float> alpha, std::complex<float> * AB, size_t lda, size_t ldb)
234+
{
235+
mkl_cimatcopy (ordering, trans, rows, cols, alpha, AB, lda, ldb);
236+
}
237+
inline void imatcopy (const char ordering, const char trans, size_t rows, size_t cols, const std::complex<double> alpha, std::complex<double> * AB, size_t lda, size_t ldb)
238+
{
239+
mkl_zimatcopy (ordering, trans, rows, cols, alpha, AB, lda, ldb);
240+
}
241+
242+
inline void omatcopy (char ordering, char trans, size_t rows, size_t cols, const float alpha, const float * A, size_t lda, float * B, size_t ldb)
243+
{
244+
mkl_somatcopy (ordering, trans, rows, cols, alpha, A, lda, B, ldb);
245+
}
246+
inline void omatcopy (char ordering, char trans, size_t rows, size_t cols, const double alpha, const double * A, size_t lda, double * B, size_t ldb)
247+
{
248+
mkl_domatcopy (ordering, trans, rows, cols, alpha, A, lda, B, ldb);
249+
}
250+
inline void omatcopy (char ordering, char trans, size_t rows, size_t cols, const std::complex<float> alpha, const std::complex<float> * A, size_t lda, std::complex<float> * B, size_t ldb)
251+
{
252+
mkl_comatcopy (ordering, trans, rows, cols, alpha, A, lda, B, ldb);
253+
}
254+
inline void omatcopy (char ordering, char trans, size_t rows, size_t cols, const std::complex<double> alpha, const std::complex<double> * A, size_t lda, std::complex<double> * B, size_t ldb)
255+
{
256+
mkl_zomatcopy (ordering, trans, rows, cols, alpha, A, lda, B, ldb);
257+
}
258+
}
259+
260+
#endif
261+
214262
}

include/RI/global/Tensor_Wrapper.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// ===================
2+
// Author: Peize Lin
3+
// date: 2022.12.09
4+
// ===================
5+
6+
#pragma once
7+
8+
#include "Global_Func-2.h"
9+
#include <vector>
10+
11+
12+
#include <numeric>
13+
14+
// Attention: very dangerous
15+
16+
namespace RI
17+
{
18+
19+
template<typename T>
20+
class Tensor_Wrapper
21+
{
22+
public:
23+
24+
std::vector<std::size_t> shape;
25+
T *ptr_ = nullptr;
26+
27+
Tensor_Wrapper()=default;
28+
explicit inline Tensor_Wrapper (const std::vector<std::size_t> &shape_in, T*const ptr_in) :shape(shape_in), ptr_(ptr_in){}
29+
30+
T* ptr()const{ return this->ptr_; }
31+
inline std::size_t get_shape_all() const;
32+
33+
// ||d||_p = (|d_1|^p+|d_2|^p+...)^{1/p}
34+
// if(p==std::numeric_limits<double>::max()) ||d||_max = max_i |d_i|
35+
Global_Func::To_Real_t<T> norm(const double p) const;
36+
};
37+
38+
}
39+
40+
#include "Tensor_Wrapper.hpp"
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// ===================
2+
// Author: Peize Lin
3+
// date: 2022.12.09
4+
// ===================
5+
6+
#pragma once
7+
8+
#include "Tensor_Wrapper.h"
9+
10+
namespace RI
11+
{
12+
template<typename T>
13+
std::size_t Tensor_Wrapper<T>::get_shape_all() const
14+
{
15+
return std::accumulate(this->shape.begin(), this->shape.end(), static_cast<std::size_t>(1), std::multiplies<std::size_t>() );
16+
}
17+
18+
template<typename T>
19+
Global_Func::To_Real_t<T> Tensor_Wrapper<T>::norm(const double p) const
20+
{
21+
using T_res = Global_Func::To_Real_t<T>;
22+
const std::size_t shape_all = get_shape_all();
23+
if(p==2)
24+
return Blas_Interface::nrm2(*this);
25+
else if(p==1)
26+
{
27+
T_res s = 0;
28+
for(std::size_t i=0; i<shape_all; ++i)
29+
s += std::abs(this->ptr_[i]);
30+
return s;
31+
}
32+
else if(p==std::numeric_limits<double>::max())
33+
{
34+
T_res s = 0;
35+
for(std::size_t i=0; i<shape_all; ++i)
36+
s = std::max(std::real(s), std::abs(this->ptr_[i]));
37+
return s;
38+
}
39+
else
40+
{
41+
T_res s = 0;
42+
for(std::size_t i=0; i<shape_all; ++i)
43+
s += std::pow(std::abs(this->ptr_[i]), p);
44+
return std::pow(s,1.0/p);
45+
}
46+
}
47+
}

include/RI/ri/CS_Matrix_Tools.hpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include "CS_Matrix_Tools.h"
99
#include "../global/Blas_Interface-Tensor.h"
10+
#include "../global/Tensor_Wrapper.h"
1011
#include <stdexcept>
1112
#include <memory.h>
1213

@@ -104,6 +105,36 @@ namespace CS_Matrix_Tools
104105
return uplimits.max();
105106
};
106107

108+
auto three_2_norm = [&D]() -> Tlim
109+
{
110+
#ifdef __MKL_RI
111+
const Tensor<Tdata> Ds_sub = Blas_Interface::omatcopy(
112+
'T', Tdata{1.0},
113+
D.reshape({D.shape[0]*D.shape[1],D.shape[2]}));
114+
#else
115+
Tensor<Tdata> Ds_sub({D.shape[2], D.shape[0]*D.shape[1]});
116+
117+
std::vector<Tdata*> Ds_sub_ptr(D.shape[2]);
118+
for(std::size_t i2=0; i2<D.shape[2]; ++i2)
119+
Ds_sub_ptr[i2] = Ds_sub.ptr()+i2*Ds_sub.shape[1]-1;
120+
121+
const Tdata* D_ptr = D.ptr()-1;
122+
const std::size_t size2 = D.shape[2];
123+
for(std::size_t i01=0; i01<Ds_sub.shape[1]; ++i01)
124+
for(std::size_t i2=0; i2<size2; ++i2)
125+
*(++Ds_sub_ptr[i2]) = *(++D_ptr);
126+
#endif
127+
128+
Tensor_Wrapper<Tdata> D_sub({D.shape[0],D.shape[1]}, nullptr);
129+
std::valarray<Global_Func::To_Real_t<Tdata>> uplimits(D.shape[2]);
130+
for(std::size_t i2=0; i2<D.shape[2]; ++i2)
131+
{
132+
D_sub.ptr_ = Ds_sub.ptr()+i2*Ds_sub.shape[1];
133+
uplimits[i2] = D_sub.norm(2);
134+
}
135+
return uplimits.max();
136+
};
137+
107138
auto norm = [](const Tensor<Tdata> &D) -> Tlim
108139
{
109140
return D.norm(2);
@@ -123,7 +154,8 @@ namespace CS_Matrix_Tools
123154
case Uplimit_Type::norm_three_1:
124155
return three_1(norm);
125156
case Uplimit_Type::norm_three_2:
126-
return three_2(norm);
157+
// return three_2(norm);
158+
return three_2_norm();
127159
case Uplimit_Type::square_two:
128160
return square(D);
129161
case Uplimit_Type::square_three_0:

include/RI/ri/LRI-cal.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include "../global/Array_Operator.h"
1111

1212
#include <omp.h>
13-
#ifdef __MKL
13+
#ifdef __MKL_RI
1414
#include <mkl_service.h>
1515
#endif
1616

@@ -41,7 +41,7 @@ void LRI<TA,Tcell,Ndim,Tdata>::cal(
4141
omp_lock_t lock_Ds_result_add;
4242
omp_init_lock(&lock_Ds_result_add);
4343

44-
#ifdef __MKL
44+
#ifdef __MKL_RI
4545
const std::size_t mkl_threads = mkl_get_max_threads();
4646
// if(!omp_get_nested())
4747
// mkl_set_num_threads(std::max(1UL,mkl_threads/list_Aa01.size()));
@@ -120,7 +120,7 @@ void LRI<TA,Tcell,Ndim,Tdata>::cal(
120120
} // end #pragma omp parallel
121121

122122
omp_destroy_lock(&lock_Ds_result_add);
123-
#ifdef __MKL
123+
#ifdef __MKL_RI
124124
mkl_set_num_threads(mkl_threads);
125125
#endif
126126
}

0 commit comments

Comments
 (0)