Skip to content

Commit 64f7516

Browse files
authored
fix lrn on mac (#14426)
* rename and fix blas vsqr test=develop * update
1 parent 8a1eeec commit 64f7516

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

paddle/fluid/operators/lrn_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
4646
int pre_pad = (n - 1) / 2;
4747
// compute batches one by one
4848
for (int i = 0; i < N; ++i) {
49-
blas.VSQR(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
49+
blas.VSQUARE(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
5050
// init the first channel of mid
5151
for (int c = 0; c < n; ++c) {
5252
blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size);

paddle/fluid/operators/math/blas.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ class Blas {
153153
void VEXP(int n, const T* x, T* y) const;
154154

155155
template <typename T>
156-
void VSQR(int n, const T* x, T* y) const;
156+
void VSQUARE(int n, const T* x, T* y) const;
157157

158158
template <typename T>
159159
void VPOW(int n, const T* x, T alpha, T* y) const;
@@ -245,8 +245,8 @@ class BlasT : private Blas<DeviceContext> {
245245
}
246246

247247
template <typename... ARGS>
248-
void VSQR(ARGS... args) const {
249-
Base()->template VSQR<T>(args...);
248+
void VSQUARE(ARGS... args) const {
249+
Base()->template VSQUARE<T>(args...);
250250
}
251251

252252
template <typename... ARGS>

paddle/fluid/operators/math/blas_impl.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct CBlas<float> {
105105
}
106106

107107
template <typename... ARGS>
108-
static void VSQR(ARGS... args) {
108+
static void VSQUARE(ARGS... args) {
109109
platform::dynload::vsSqr(args...);
110110
}
111111

@@ -195,7 +195,7 @@ struct CBlas<double> {
195195
}
196196

197197
template <typename... ARGS>
198-
static void VSQR(ARGS... args) {
198+
static void VSQUARE(ARGS... args) {
199199
platform::dynload::vdSqr(args...);
200200
}
201201

@@ -262,7 +262,9 @@ struct CBlas<platform::float16> {
262262
}
263263
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
264264
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
265-
static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); }
265+
static void VSQUARE(...) {
266+
PADDLE_THROW("float16 VSQUARE not supported on CPU");
267+
}
266268
static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
267269
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
268270
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
@@ -423,12 +425,12 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
423425

424426
template <>
425427
template <typename T>
426-
void Blas<platform::CPUDeviceContext>::VSQR(int n, const T *x, T *y) const {
428+
void Blas<platform::CPUDeviceContext>::VSQUARE(int n, const T *x, T *y) const {
427429
#ifdef PADDLE_WITH_MKLML
428-
CBlas<T>::VSQR(n, x, y);
430+
CBlas<T>::VSQUARE(n, x, y);
429431
#else
430432
for (int i = 0; i < n; ++i) {
431-
y[i] = std::sqrt(x[i]);
433+
y[i] = x[i] * x[i];
432434
}
433435
#endif
434436
}

0 commit comments

Comments
 (0)