Skip to content

Commit 43c49ad

Browse files
committed
Add blas_nrm2
1 parent 247cdd2 commit 43c49ad

File tree

7 files changed

+108
-0
lines changed

7 files changed

+108
-0
lines changed

source/source_base/module_container/ATen/kernels/blas.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
namespace container {
44
namespace kernels {
55

6+
67
template <typename T>
78
struct blas_copy<T, DEVICE_CPU> {
89
void operator()(
@@ -16,6 +17,18 @@ struct blas_copy<T, DEVICE_CPU> {
1617
}
1718
};
1819

20+
template <typename T>
21+
struct blas_nrm2<T, DEVICE_CPU> {
22+
using Real = typename GetTypeReal<T>::type;
23+
Real operator()(
24+
const int n,
25+
const T *x,
26+
const int incx)
27+
{
28+
return BlasConnector::nrm2(n, x, incx);
29+
}
30+
};
31+
1932
template <typename T>
2033
struct blas_dot<T, DEVICE_CPU> {
2134
void operator()(
@@ -194,6 +207,11 @@ template struct blas_copy<double, DEVICE_CPU>;
194207
template struct blas_copy<std::complex<float >, DEVICE_CPU>;
195208
template struct blas_copy<std::complex<double>, DEVICE_CPU>;
196209

210+
// Explicit instantiations of the CPU nrm2 functor: real and complex, single and double precision.
template struct blas_nrm2<float, DEVICE_CPU>;
template struct blas_nrm2<double, DEVICE_CPU>;
template struct blas_nrm2<std::complex<float>, DEVICE_CPU>;
template struct blas_nrm2<std::complex<double>, DEVICE_CPU>;
214+
197215
template struct blas_dot<float , DEVICE_CPU>;
198216
template struct blas_dot<double, DEVICE_CPU>;
199217
template struct blas_dot<std::complex<float >, DEVICE_CPU>;

source/source_base/module_container/ATen/kernels/blas.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ struct blas_copy {
1919
const int incy);
2020
};
2121

22+
template <typename T, typename Device>
23+
struct blas_nrm2 {
24+
using Real = typename GetTypeReal<T>::type;
25+
Real operator()(
26+
const int n,
27+
const T *x,
28+
const int incx);
29+
};
2230

2331
template <typename T, typename Device>
2432
struct blas_dot {

source/source_base/module_container/ATen/kernels/cuda/blas.cu

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,20 @@ void destroyGpuBlasHandle() {
2222
}
2323
}
2424

25+
template <typename T>
26+
struct blas_nrm2<T, DEVICE_GPU> {
27+
using Real = typename GetTypeReal<T>::type;
28+
Real operator()(
29+
const int n,
30+
const T *x,
31+
const int incx)
32+
{
33+
Real result;
34+
cuBlasConnector::nrm2(cublas_handle, n, x, incx, &result);
35+
return result;
36+
}
37+
};
38+
2539
template <typename T>
2640
struct blas_copy<T, DEVICE_GPU> {
2741
void operator()(
@@ -209,11 +223,18 @@ struct blas_gemm_batched_strided<T, DEVICE_GPU> {
209223

210224
// Explicitly instantiate functors for the types of functor registered.
211225

226+
227+
212228
template struct blas_copy<float , DEVICE_GPU>;
213229
template struct blas_copy<double, DEVICE_GPU>;
214230
template struct blas_copy<std::complex<float> , DEVICE_GPU>;
215231
template struct blas_copy<std::complex<double>, DEVICE_GPU>;
216232

233+
// Explicit instantiations of the GPU nrm2 functor: real and complex, single and double precision.
template struct blas_nrm2<float, DEVICE_GPU>;
template struct blas_nrm2<double, DEVICE_GPU>;
template struct blas_nrm2<std::complex<float>, DEVICE_GPU>;
template struct blas_nrm2<std::complex<double>, DEVICE_GPU>;
237+
217238
template struct blas_dot<float , DEVICE_GPU>;
218239
template struct blas_dot<double, DEVICE_GPU>;
219240
template struct blas_dot<std::complex<float> , DEVICE_GPU>;

source/source_base/module_container/ATen/kernels/rocm/blas.hip.cu

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ void destroyGpuBlasHandle() {
2323
}
2424

2525

26+
template <typename T>
27+
struct blas_nrm2<T, DEVICE_GPU> {
28+
T operator()(
29+
const int n,
30+
const T *x,
31+
const int incx)
32+
{
33+
T result;
34+
hipBlasConnector::nrm2(hipblas_handle, n, x, incx, &result);
35+
return result;
36+
}
37+
};
38+
2639
template <typename T>
2740
struct blas_dot<T, DEVICE_GPU> {
2841
void operator()(
@@ -196,6 +209,11 @@ struct blas_gemm_batched_strided<T, DEVICE_GPU> {
196209
};
197210

198211
// Explicitly instantiate functors for the types of functor registered.
212+
// nrm2 functor: real and complex, single and double precision.
template struct blas_nrm2<float, DEVICE_GPU>;
template struct blas_nrm2<double, DEVICE_GPU>;
template struct blas_nrm2<std::complex<float>, DEVICE_GPU>;
template struct blas_nrm2<std::complex<double>, DEVICE_GPU>;
216+
199217
template struct blas_dot<float , DEVICE_GPU>;
200218
template struct blas_dot<double, DEVICE_GPU>;
201219
template struct blas_dot<std::complex<float> , DEVICE_GPU>;

source/source_base/module_container/ATen/kernels/test/blas_test.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ TYPED_TEST(BlasTest, Copy) {
3636
EXPECT_EQ(y, expected);
3737
}
3838

39+
// Checks blas_nrm2 on the vector (3, 4, 0), whose Euclidean norm is 5.
TYPED_TEST(BlasTest, Nrm2) {
    using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
    using Device = typename std::tuple_element<1, decltype(TypeParam())>::type;

    blas_nrm2<Type, Device> nrm2Calculator;

    const int n = 3;
    const Tensor x = std::move(Tensor({static_cast<Type>(3.0), static_cast<Type>(4.0), static_cast<Type>(0.0)}).to_device<Device>());

    // The functor takes (n, x, incx) and returns the norm by value; there is
    // no output-pointer parameter. The result is a real scalar, so it is
    // compared against plain floating-point literals rather than Type.
    const auto result = nrm2Calculator(n, x.data<Type>(), 1);

    EXPECT_NEAR(result, 5.0, 1e-6);
}
54+
3955
TYPED_TEST(BlasTest, Dot) {
4056
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
4157
using Device = typename std::tuple_element<1, decltype(TypeParam())>::type;

source/source_base/module_container/base/third_party/blas.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ double ddot_(const int *N, const double *x, const int *incx, const double *y, co
4343
// Peize Lin add ?nrm2 2018-06-12, to compute out = ||x||_2 = \sqrt{ \sum_i |x_i|**2 }
4444
float snrm2_( const int *n, const float *x, const int *incx );
4545
double dnrm2_( const int *n, const double *x, const int *incx );
46+
// Single-precision complex input: the routine returns a single-precision real, like snrm2_.
float scnrm2_( const int *n, const std::complex<float> *x, const int *incx );
4647
double dznrm2_( const int *n, const std::complex<double> *x, const int *incx );
4748

4849
// level 2: matrix-vector operations, O(n^2) data and O(n^2) work.
@@ -334,6 +335,11 @@ double nrm2( const int n, const double *x, const int incx )
334335
return dnrm2_( &n, x, &incx );
335336
}
336337
static inline
338+
double nrm2( const int n, const std::complex<float> *x, const int incx )
339+
{
340+
return scnrm2_( &n, x, &incx );
341+
}
342+
static inline
337343
double nrm2( const int n, const std::complex<double> *x, const int incx )
338344
{
339345
return dznrm2_( &n, x, &incx );

source/source_base/module_container/base/third_party/cublas.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,27 @@ void copy(cublasHandle_t& handle, const int& n, const std::complex<double> *x, c
2929
cublasErrcheck(cublasZcopy(handle, n, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<cuDoubleComplex*>(y), incy));
3030
}
3131

32+
static inline
33+
void nrm2(cublasHandle_t& handle, const int& n, const float *x, const int& incx, float* result)
34+
{
35+
cublasErrcheck(cublasSnrm2(handle, n, x, incx, result));
36+
}
37+
static inline
38+
void nrm2(cublasHandle_t& handle, const int& n, const double *x, const int& incx, double* result)
39+
{
40+
cublasErrcheck(cublasDnrm2(handle, n, x, incx, result));
41+
}
42+
static inline
43+
void nrm2(cublasHandle_t& handle, const int& n, const std::complex<float> *x, const int& incx, float* result)
44+
{
45+
cublasErrcheck(cublasScnrm2(handle, n, reinterpret_cast<const cuComplex*>(x), incx, result));
46+
}
47+
static inline
48+
void nrm2(cublasHandle_t& handle, const int& n, const std::complex<double> *x, const int& incx, double* result)
49+
{
50+
cublasErrcheck(cublasDznrm2(handle, n, reinterpret_cast<const cuDoubleComplex*>(x), incx, result));
51+
}
52+
3253
static inline
3354
void dot(cublasHandle_t& handle, const int& n, const float *x, const int& incx, const float *y, const int& incy, float* result)
3455
{

0 commit comments

Comments
 (0)