Skip to content

Commit 247cdd2

Browse files
committed
Add blas_copy
1 parent f689729 commit 247cdd2

File tree

6 files changed

+177
-80
lines changed

6 files changed

+177
-80
lines changed

source/source_base/module_container/ATen/kernels/blas.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,19 @@
33
namespace container {
44
namespace kernels {
55

6+
template <typename T>
7+
struct blas_copy<T, DEVICE_CPU> {
8+
void operator()(
9+
const int n,
10+
const T *x,
11+
const int incx,
12+
T *y,
13+
const int incy)
14+
{
15+
BlasConnector::copy(n, x, incx, y, incy);
16+
}
17+
};
18+
619
template <typename T>
720
struct blas_dot<T, DEVICE_CPU> {
821
void operator()(
@@ -175,6 +188,12 @@ struct blas_gemm_batched_strided<T, DEVICE_CPU> {
175188
};
176189

177190
// Explicitly instantiate functors for the types of functor registered.
191+
192+
template struct blas_copy<float , DEVICE_CPU>;
193+
template struct blas_copy<double, DEVICE_CPU>;
194+
template struct blas_copy<std::complex<float >, DEVICE_CPU>;
195+
template struct blas_copy<std::complex<double>, DEVICE_CPU>;
196+
178197
template struct blas_dot<float , DEVICE_CPU>;
179198
template struct blas_dot<double, DEVICE_CPU>;
180199
template struct blas_dot<std::complex<float >, DEVICE_CPU>;
@@ -221,4 +240,4 @@ template struct blas_gemm_batched_strided<std::complex<float >, DEVICE_CPU>;
221240
template struct blas_gemm_batched_strided<std::complex<double>, DEVICE_CPU>;
222241

223242
} // namespace kernels
224-
} // namespace container
243+
} // namespace container

source/source_base/module_container/ATen/kernels/blas.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99
namespace container {
1010
namespace kernels {
1111

12+
template <typename T, typename Device>
13+
struct blas_copy {
14+
void operator()(
15+
const int n,
16+
const T *x,
17+
const int incx,
18+
T *y,
19+
const int incy);
20+
};
21+
22+
1223
template <typename T, typename Device>
1324
struct blas_dot {
1425
void operator()(
@@ -168,4 +179,4 @@ void destroyGpuBlasHandle(); // destroy blas handle
168179
} // namespace kernels
169180
} // namespace container
170181

171-
#endif // ATEN_KERNELS_BLAS_H_
182+
#endif // ATEN_KERNELS_BLAS_H_

source/source_base/module_container/ATen/kernels/cuda/blas.cu

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ void destroyGpuBlasHandle() {
2222
}
2323
}
2424

25+
template <typename T>
26+
struct blas_copy<T, DEVICE_GPU> {
27+
void operator()(
28+
const int n,
29+
const T * x,
30+
const int incx,
31+
T *y,
32+
const int incy)
33+
{
34+
cuBlasConnector::copy(cublas_handle, n, x, incx, y, incy);
35+
}
36+
};
2537

2638
template <typename T>
2739
struct blas_dot<T, DEVICE_GPU> {
@@ -76,7 +88,7 @@ struct blas_gemv<T, DEVICE_GPU> {
7688
const int& incx,
7789
const T* beta,
7890
T* y,
79-
const int& incy)
91+
const int& incy)
8092
{
8193
cuBlasConnector::gemv(cublas_handle, trans, m, n, *alpha, A, lda, x, incx, *beta, y, incy);
8294
}
@@ -196,6 +208,12 @@ struct blas_gemm_batched_strided<T, DEVICE_GPU> {
196208
};
197209

198210
// Explicitly instantiate functors for the types of functor registered.
211+
212+
template struct blas_copy<float , DEVICE_GPU>;
213+
template struct blas_copy<double, DEVICE_GPU>;
214+
template struct blas_copy<std::complex<float> , DEVICE_GPU>;
215+
template struct blas_copy<std::complex<double>, DEVICE_GPU>;
216+
199217
template struct blas_dot<float , DEVICE_GPU>;
200218
template struct blas_dot<double, DEVICE_GPU>;
201219
template struct blas_dot<std::complex<float> , DEVICE_GPU>;
@@ -242,4 +260,4 @@ template struct blas_gemm_batched_strided<std::complex<float >, DEVICE_GPU>;
242260
template struct blas_gemm_batched_strided<std::complex<double>, DEVICE_GPU>;
243261

244262
} // namespace kernels
245-
} // namespace container
263+
} // namespace container

source/source_base/module_container/ATen/kernels/test/blas_test.cpp

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,22 @@ class BlasTest : public testing::Test {
2020

2121
TYPED_TEST_SUITE(BlasTest, base::utils::Types);
2222

23+
TYPED_TEST(BlasTest, Copy) {
24+
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
25+
using Device = typename std::tuple_element<1, decltype(TypeParam())>::type;
26+
27+
blas_copy<Type, Device> copyCalculator;
28+
29+
const int n = 3;
30+
const Tensor x = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0)}).to_device<Device>());
31+
Tensor y = std::move(Tensor({static_cast<Type>(0.0), static_cast<Type>(0.0), static_cast<Type>(0.0)}).to_device<Device>());
32+
33+
copyCalculator(n, x.data<Type>(), 1, y.data<Type>(), 1);
34+
const Tensor expected = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0)}).to_device<Device>());
35+
36+
EXPECT_EQ(y, expected);
37+
}
38+
2339
TYPED_TEST(BlasTest, Dot) {
2440
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
2541
using Device = typename std::tuple_element<1, decltype(TypeParam())>::type;
@@ -29,7 +45,7 @@ TYPED_TEST(BlasTest, Dot) {
2945
const int n = 3;
3046
const Tensor x = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0)}).to_device<Device>());
3147
const Tensor y = std::move(Tensor({static_cast<Type>(4.0), static_cast<Type>(5.0), static_cast<Type>(6.0)}).to_device<Device>());
32-
48+
3349
Type result = {};
3450
dotCalculator(n, x.data<Type>(), 1, y.data<Type>(), 1, &result);
3551
const Type expected = static_cast<Type>(32.0);
@@ -46,7 +62,7 @@ TYPED_TEST(BlasTest, Scal) {
4662
const int n = 3;
4763
const Type alpha = static_cast<Type>(2.0);
4864
Tensor x = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0)}).to_device<Device>());
49-
65+
5066
scalCalculator(n, &alpha, x.data<Type>(), 1);
5167
const Tensor expected = std::move(Tensor({static_cast<Type>(2.0), static_cast<Type>(4.0), static_cast<Type>(6.0)}).to_device<Device>());
5268

@@ -64,7 +80,7 @@ TYPED_TEST(BlasTest, Axpy) {
6480
const Type alpha = static_cast<Type>(2.0);
6581
const Tensor x = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0)}).to_device<Device>());
6682
Tensor y = std::move(Tensor({static_cast<Type>(4.0), static_cast<Type>(5.0), static_cast<Type>(6.0)}).to_device<Device>());
67-
83+
6884
axpyCalculator(n, &alpha, x.data<Type>(), 1, y.data<Type>(), 1);
6985
const Tensor expected = std::move(Tensor({static_cast<Type>(6.0), static_cast<Type>(9.0), static_cast<Type>(12.0)}).to_device<Device>());
7086

@@ -83,11 +99,11 @@ TYPED_TEST(BlasTest, Gemv) {
8399
const int n = 2;
84100
const Type alpha = static_cast<Type>(2.0);
85101
const Type beta = static_cast<Type>(3.0);
86-
const Tensor A = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0),
102+
const Tensor A = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0),
87103
static_cast<Type>(4.0), static_cast<Type>(5.0), static_cast<Type>(6.0)}).to_device<Device>());
88104
const Tensor x = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0)}).to_device<Device>());
89105
Tensor y = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0)}).to_device<Device>());
90-
106+
91107
gemvCalculator(trans, m, n, &alpha, A.data<Type>(), m, x.data<Type>(), 1, &beta, y.data<Type>(), 1);
92108
const Tensor expected = std::move(Tensor({static_cast<Type>(21.0), static_cast<Type>(30.0), static_cast<Type>(39.0)}).to_device<Device>());
93109

@@ -114,14 +130,14 @@ TYPED_TEST(BlasTest, GemvBatched) {
114130
std::vector<Type*> y = {};
115131

116132
const Tensor _A = std::move(Tensor({
117-
static_cast<Type>(1.0), static_cast<Type>(2.0),
118-
static_cast<Type>(3.0), static_cast<Type>(4.0),
133+
static_cast<Type>(1.0), static_cast<Type>(2.0),
134+
static_cast<Type>(3.0), static_cast<Type>(4.0),
119135
static_cast<Type>(5.0), static_cast<Type>(6.0),
120-
136+
121137
static_cast<Type>(7.0), static_cast<Type>(8.0),
122138
static_cast<Type>(9.0), static_cast<Type>(10.0),
123139
static_cast<Type>(11.0),static_cast<Type>(12.0)}).to_device<Device>());
124-
140+
125141
A.push_back(_A.data<Type>());
126142
A.push_back(_A.data<Type>() + m * n);
127143

@@ -164,14 +180,14 @@ TYPED_TEST(BlasTest, GemvBatchedStrided) {
164180
std::vector<Type*> y = {};
165181

166182
const Tensor _A = std::move(Tensor({
167-
static_cast<Type>(1.0), static_cast<Type>(2.0),
168-
static_cast<Type>(3.0), static_cast<Type>(4.0),
183+
static_cast<Type>(1.0), static_cast<Type>(2.0),
184+
static_cast<Type>(3.0), static_cast<Type>(4.0),
169185
static_cast<Type>(5.0), static_cast<Type>(6.0),
170-
186+
171187
static_cast<Type>(7.0), static_cast<Type>(8.0),
172188
static_cast<Type>(9.0), static_cast<Type>(10.0),
173189
static_cast<Type>(11.0),static_cast<Type>(12.0)}).to_device<Device>());
174-
190+
175191
A.push_back(_A.data<Type>());
176192
A.push_back(_A.data<Type>() + m * n);
177193

@@ -205,11 +221,11 @@ TYPED_TEST(BlasTest, Gemm) {
205221
const int n = 2;
206222
const Type alpha = static_cast<Type>(2.0);
207223
const Type beta = static_cast<Type>(3.0);
208-
const Tensor A = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0),
224+
const Tensor A = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0),
209225
static_cast<Type>(4.0), static_cast<Type>(5.0), static_cast<Type>(6.0)}).to_device<Device>());
210226
const Tensor x = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0)}).to_device<Device>());
211227
Tensor y = std::move(Tensor({static_cast<Type>(1.0), static_cast<Type>(2.0), static_cast<Type>(3.0)}).to_device<Device>());
212-
228+
213229
gemmCalculator(trans, trans, m, 1, n, &alpha, A.data<Type>(), m, x.data<Type>(), n, &beta, y.data<Type>(), m);
214230
const Tensor expected = std::move(Tensor({static_cast<Type>(21.0), static_cast<Type>(30.0), static_cast<Type>(39.0)}).to_device<Device>());
215231

@@ -237,14 +253,14 @@ TYPED_TEST(BlasTest, GemmBatched) {
237253
std::vector<Type*> y2 = {};
238254

239255
const Tensor _A = std::move(Tensor({
240-
static_cast<Type>(1.0), static_cast<Type>(2.0),
241-
static_cast<Type>(3.0), static_cast<Type>(4.0),
256+
static_cast<Type>(1.0), static_cast<Type>(2.0),
257+
static_cast<Type>(3.0), static_cast<Type>(4.0),
242258
static_cast<Type>(5.0), static_cast<Type>(6.0),
243-
259+
244260
static_cast<Type>(7.0), static_cast<Type>(8.0),
245261
static_cast<Type>(9.0), static_cast<Type>(10.0),
246262
static_cast<Type>(11.0),static_cast<Type>(12.0)}).to_device<Device>());
247-
263+
248264
A.push_back(_A.data<Type>());
249265
A.push_back(_A.data<Type>() + m * n);
250266

@@ -287,14 +303,14 @@ TYPED_TEST(BlasTest, GemmBatchedStrided) {
287303
std::vector<Type*> y2 = {};
288304

289305
const Tensor _A = std::move(Tensor({
290-
static_cast<Type>(1.0), static_cast<Type>(2.0),
291-
static_cast<Type>(3.0), static_cast<Type>(4.0),
306+
static_cast<Type>(1.0), static_cast<Type>(2.0),
307+
static_cast<Type>(3.0), static_cast<Type>(4.0),
292308
static_cast<Type>(5.0), static_cast<Type>(6.0),
293-
309+
294310
static_cast<Type>(7.0), static_cast<Type>(8.0),
295311
static_cast<Type>(9.0), static_cast<Type>(10.0),
296312
static_cast<Type>(11.0),static_cast<Type>(12.0)}).to_device<Device>());
297-
313+
298314
A.push_back(_A.data<Type>());
299315
A.push_back(_A.data<Type>() + m * n);
300316

source/source_base/module_container/base/third_party/blas.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,16 @@ void daxpy_(const int *N, const double *alpha, const double *x, const int *incx,
2525
void caxpy_(const int *N, const std::complex<float> *alpha, const std::complex<float> *x, const int *incx, std::complex<float> *y, const int *incy);
2626
void zaxpy_(const int *N, const std::complex<double> *alpha, const std::complex<double> *x, const int *incx, std::complex<double> *y, const int *incy);
2727

28+
void scopy_(long const *n, const float *a, int const *incx, float *b, int const *incy);
2829
void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy);
30+
void ccopy_(long const *n, const std::complex<float> *a, int const *incx, std::complex<float> *b, int const *incy);
2931
void zcopy_(long const *n, const std::complex<double> *a, int const *incx, std::complex<double> *b, int const *incy);
3032

3133
//reason for passing results as argument instead of returning it:
3234
//see https://www.numbercrunch.de/blog/2014/07/lost-in-translation/
33-
void cdotc_(const int *n, const std::complex<float> *zx, const int *incx,
35+
void cdotc_(const int *n, const std::complex<float> *zx, const int *incx,
3436
const std::complex<float> *zy, const int *incy, std::complex<float> *result);
35-
void zdotc_(const int *n, const std::complex<double> *zx, const int *incx,
37+
void zdotc_(const int *n, const std::complex<double> *zx, const int *incx,
3638
const std::complex<double> *zy, const int *incy, std::complex<double> *result);
3739
// Peize Lin add ?dot 2017-10-27, to compute d=x*y
3840
float sdot_(const int *N, const float *x, const int *incx, const float *y, const int *incy);
@@ -339,11 +341,21 @@ double nrm2( const int n, const std::complex<double> *x, const int incx )
339341

340342
// copies a into b
341343
static inline
344+
void copy(const long n, const float *a, const int incx, float *b, const int incy)
345+
{
346+
scopy_(&n, a, &incx, b, &incy);
347+
}
348+
static inline
342349
void copy(const long n, const double *a, const int incx, double *b, const int incy)
343350
{
344351
dcopy_(&n, a, &incx, b, &incy);
345352
}
346353
static inline
354+
void copy(const long n, const std::complex<float> *a, const int incx, std::complex<float> *b, const int incy)
355+
{
356+
ccopy_(&n, a, &incx, b, &incy);
357+
}
358+
static inline
347359
void copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy)
348360
{
349361
zcopy_(&n, a, &incx, b, &incy);

0 commit comments

Comments
 (0)