Skip to content

Commit e7e6a4d

Browse files
authored
modify interfaces cdotc,zdotc,cdotu,zdotu to support openblas (#3)
* implement dotc and dotu by sdot_ and ddot_ * UT for dotc and dotu with different incX and incY * fix a typo
1 parent 16539e1 commit e7e6a4d

File tree

3 files changed

+48
-16
lines changed

3 files changed

+48
-16
lines changed

include/RI/global/Blas-Fortran.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ extern "C"
3131
// d = Vx * Vy
3232
// reason for passing results as argument instead of returning it:
3333
// https://www.numbercrunch.de/blog/2014/07/lost-in-translation/
34-
void cdotu_(std::complex<float>*const result, const int*const n, const std::complex<float>*const X, const int*const incX, const std::complex<float>*const Y, const int*const incY);
35-
void zdotu_(std::complex<double>*const result, const int*const n, const std::complex<double>*const X, const int*const incX, const std::complex<double>*const Y, const int*const incY);
34+
// void cdotu_(std::complex<float>*const result, const int*const n, const std::complex<float>*const X, const int*const incX, const std::complex<float>*const Y, const int*const incY);
35+
// void zdotu_(std::complex<double>*const result, const int*const n, const std::complex<double>*const X, const int*const incX, const std::complex<double>*const Y, const int*const incY);
3636

3737
// d = Vx * Vy
38-
void cdotc_(std::complex<float>*const result, const int*const n, const std::complex<float>*const X, const int*const incX, const std::complex<float>*const Y, const int*const incY);
39-
void zdotc_(std::complex<double>*const result, const int*const n, const std::complex<double>*const X, const int*const incX, const std::complex<double>*const Y, const int*const incY);
38+
// void cdotc_(std::complex<float>*const result, const int*const n, const std::complex<float>*const X, const int*const incX, const std::complex<float>*const Y, const int*const incY);
39+
// void zdotc_(std::complex<double>*const result, const int*const n, const std::complex<double>*const X, const int*const incX, const std::complex<double>*const Y, const int*const incY);
4040

4141
// Vy = alpha * Ma.? * Vx + beta * Vy
4242
void sgemv_(const char*const transA, const int*const m, const int*const n,

include/RI/global/Blas_Interface.h

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,27 @@ namespace Blas_Interface
9696
}
9797
inline std::complex<float> dotu(const int n, const std::complex<float>*const X, const int incX, const std::complex<float>*const Y, const int incY)
9898
{
99-
std::complex<float> result;
100-
cdotu_(&result, &n, X, &incX, Y, &incY);
101-
return result;
99+
//cdotu_(&result, &n, X, &incX, Y, &incY);
100+
const int incX2 = 2 * incX;
101+
const int incY2 = 2 * incY;
102+
auto x = reinterpret_cast<const float*>(X);
103+
auto y = reinterpret_cast<const float*>(Y);
104+
//Re(result)=Re(x)*Re(y)-Im(x)*Im(y)
105+
//Im(result)=Re(x)*Im(y)+Im(x)*Re(y)
106+
return std::complex<float>(sdot_(&n, x, &incX2, y, &incY2) - sdot_(&n, x + 1, &incX2, y + 1, &incY2),
107+
sdot_(&n, x, &incX2, y + 1, &incY2) + sdot_(&n, x + 1, &incX2, y, &incY2));
102108
}
103109
inline std::complex<double> dotu(const int n, const std::complex<double>*const X, const int incX, const std::complex<double>*const Y, const int incY)
104110
{
105-
std::complex<double> result;
106-
zdotu_(&result, &n, X, &incX, Y, &incY);
107-
return result;
111+
//zdotu_(&result, &n, X, &incX, Y, &incY);
112+
const int incX2 = 2 * incX;
113+
const int incY2 = 2 * incY;
114+
auto x = reinterpret_cast<const double*>(X);
115+
auto y = reinterpret_cast<const double*>(Y);
116+
//Re(result)=Re(x)*Re(y)-Im(x)*Im(y)
117+
//Im(result)=Re(x)*Im(y)+Im(x)*Re(y)
118+
return std::complex<double>(ddot_(&n, x, &incX2, y, &incY2) - ddot_(&n, x + 1, &incX2, y + 1, &incY2),
119+
ddot_(&n, x, &incX2, y + 1, &incY2) + ddot_(&n, x + 1, &incX2, y, &incY2));
108120
}
109121

110122
// d = Vx.conj() * Vy
@@ -118,15 +130,27 @@ namespace Blas_Interface
118130
}
119131
inline std::complex<float> dotc(const int n, const std::complex<float>*const X, const int incX, const std::complex<float>*const Y, const int incY)
120132
{
121-
std::complex<float> result;
122-
cdotc_(&result, &n, X, &incX, Y, &incY);
123-
return result;
133+
//cdotc_(&result, &n, X, &incX, Y, &incY);
134+
const int incX2 = 2 * incX;
135+
const int incY2 = 2 * incY;
136+
auto x = reinterpret_cast<const float*>(X);
137+
auto y = reinterpret_cast<const float*>(Y);
138+
// Re(result)=Re(X)*Re(Y)+Im(X)*Im(Y)
139+
// Im(result)=Re(X)*Im(Y)-Im(X)*Re(Y)
140+
return std::complex<float>(sdot_(&n, x, &incX2, y, &incY2) + sdot_(&n, x + 1, &incX2, y + 1, &incY2),
141+
sdot_(&n, x, &incX2, y + 1, &incY2) - sdot_(&n, x + 1, &incX2, y, &incY2));
124142
}
125143
inline std::complex<double> dotc(const int n, const std::complex<double>*const X, const int incX, const std::complex<double>*const Y, const int incY)
126144
{
127-
std::complex<double> result;
128-
zdotc_(&result, &n, X, &incX, Y, &incY);
129-
return result;
145+
//zdotc_(&result, &n, X, &incX, Y, &incY);
146+
const int incX2 = 2 * incX;
147+
const int incY2 = 2 * incY;
148+
auto x = reinterpret_cast<const double*>(X);
149+
auto y = reinterpret_cast<const double*>(Y);
150+
// Re(result)=Re(X)*Re(Y)+Im(X)*Im(Y)
151+
// Im(result)=Re(X)*Im(Y)-Im(X)*Re(Y)
152+
return std::complex<double>(ddot_(&n, x, &incX2, y, &incY2) + ddot_(&n, x + 1, &incX2, y + 1, &incY2),
153+
ddot_(&n, x, &incX2, y+1, &incY2) - ddot_(&n, x + 1, &incX2, y, &incY2));
130154
}
131155

132156
// Vy = alpha * Ma.? * Vx + beta * Vy

unittests/global/Blas-test.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ namespace Blas_Test
7272
/* -18+68i */
7373
std::cout<<RI::Blas_Interface::dotc(a.size(), a.data(), 1, b.data(), 1)<<std::endl;
7474
/* 70-8i */
75+
76+
// test for different incX and incY
77+
const std::vector<Tdata> a3 = { {1,2}, {1,4}, {4,2}, {3,4}, {2,8}, {1,8}};
78+
const std::vector<Tdata> b2 = { {5,6}, {5,7}, {7,8}, {7,5}};
79+
std::cout << RI::Blas_Interface::dotu(a.size(), a3.data(), 3, b2.data(), 2) << std::endl;
80+
/* -18+68i */
81+
std::cout << RI::Blas_Interface::dotc(a.size(), a3.data(), 3, b2.data(), 2) << std::endl;
82+
/* 70-8i */
7583
}
7684

7785
template<typename Tdata>

0 commit comments

Comments
 (0)