Skip to content

Commit 2b0811c

Browse files
committed
refine vadd jitkernel choice
test=develop
1 parent a18c0d4 commit 2b0811c

File tree

4 files changed

+7
-2
lines changed

4 files changed

+7
-2
lines changed

paddle/fluid/operators/jit/benchmark.cc

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,7 @@ std::vector<int> TestSizes() {
9393
template <typename KernelTuples, typename... Args>
9494
struct BenchFunc {
9595
// return this function avg time
96+
// TODO(TJ): clear cache every time
9697
double operator()(const typename KernelTuples::func_type tgt, Args... args) {
9798
for (int i = 0; i < FLAGS_burning; ++i) {
9899
tgt(args...);
@@ -172,6 +173,9 @@ void BenchXYZNKernel() {
172173
RandomVec<T>(d, y_data);
173174
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(),
174175
y.data<T>(), z_data, d);
176+
// test inplace
177+
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data,
178+
z_data, d);
175179
}
176180
}
177181

paddle/fluid/operators/jit/gen/blas.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -155,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
155155
class name##Creator : public JitCodeCreator<int> { \
156156
public: \
157157
bool UseMe(const int& attr) const override { \
158-
return platform::MayIUse(platform::avx); \
158+
return platform::MayIUse(platform::avx) && attr <= 1024; \
159159
} \
160160
size_t CodeSize(const int& d) const override { \
161161
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \

paddle/fluid/operators/jit/gen/blas.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,7 @@ class VXXJitCode : public JitCode {
6161
base += "_Vec";
6262
}
6363
base += (with_relu_ ? "_Relu" : "");
64+
base += "_D" + std::to_string(num_);
6465
return base.c_str();
6566
}
6667
void genCode() override;

paddle/fluid/operators/jit/more/mkl/mkl.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,7 @@ bool VMulKernel<float>::UseMe(const int& d) const {
139139

140140
template <>
141141
bool VAddKernel<float>::UseMe(const int& d) const {
142-
return platform::MayIUse(platform::avx512f) && d > 512;
142+
return platform::MayIUse(platform::avx) && d > 512;
143143
}
144144

145145
template <>

0 commit comments

Comments
 (0)