Skip to content

Commit e2d6edd

Browse files
committed
remove ComputeDeprecated
test=develop
1 parent f65ddff commit e2d6edd

File tree

6 files changed

+53
-78
lines changed

6 files changed

+53
-78
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ bool VActJitCode::init(int d, operand_type type) {
178178
if (type == operand_type::relu) {
179179
return ok;
180180
} else {
181+
// TODO(TJ): support more
181182
return ok && d == 8; // only 8 yet
182183
}
183184
}

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -98,42 +98,23 @@ class VAddBiasKernel : public Kernel {
9898
template <typename T>
9999
class VActKernel : public Kernel {
100100
public:
101-
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
101+
void (*Compute)(const T *, T *, int);
102102
};
103103

104104
template <typename T>
105-
class VReluKernel : public VActKernel<T> {
106-
public:
107-
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
108-
void (*Compute)(const T *, T *, int);
109-
};
105+
class VReluKernel : public VActKernel<T> {};
110106

111107
template <typename T>
112-
class VIdentityKernel : public VActKernel<T> {
113-
public:
114-
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
115-
};
108+
class VIdentityKernel : public VActKernel<T> {};
116109

117110
template <typename T>
118-
class VExpKernel : public VActKernel<T> {
119-
public:
120-
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
121-
void (*Compute)(const T *, T *, int);
122-
};
111+
class VExpKernel : public VActKernel<T> {};
123112

124113
template <typename T>
125-
class VSigmoidKernel : public VActKernel<T> {
126-
public:
127-
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
128-
void (*Compute)(const T *, T *, int);
129-
};
114+
class VSigmoidKernel : public VActKernel<T> {};
130115

131116
template <typename T>
132-
class VTanhKernel : public VActKernel<T> {
133-
public:
134-
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
135-
void (*Compute)(const T *, T *, int);
136-
};
117+
class VTanhKernel : public VActKernel<T> {};
137118

138119
template <typename T>
139120
class LSTMKernel : public Kernel {

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,6 @@ class VReluKernelImpl : public VReluKernel<T> {
346346
public:
347347
JITKERNEL_DECLARE_STATIC_FUNC;
348348
explicit VReluKernelImpl(int d) : VReluKernel<T>() {
349-
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
350349
#ifdef PADDLE_WITH_XBYAK
351350
if (useJIT(d)) {
352351
size_t sz = 96 /* init size */ +
@@ -361,9 +360,6 @@ class VReluKernelImpl : public VReluKernel<T> {
361360

362361
this->Compute = VReluRefer<T>;
363362
}
364-
void ComputeDeprecated(const T* x, T* y) const override {
365-
VReluRefer(x, y, this->num_);
366-
}
367363
#ifdef PADDLE_WITH_XBYAK
368364

369365
private:
@@ -378,22 +374,26 @@ bool VReluKernelImpl<float>::useJIT(int d) {
378374
}
379375
#endif
380376

381-
REGISTER_JITKERNEL(vmul, VMulKernel);
382-
REGISTER_JITKERNEL(vadd, VAddKernel);
383-
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
384-
REGISTER_JITKERNEL(vscal, VScalKernel);
385-
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
386-
REGISTER_JITKERNEL(vrelu, VReluKernel);
377+
template <typename T>
378+
inline void VIdentityRefer(const T* x, T* y, int n) {}
387379

388380
/* An empty JitKernel */
389-
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
381+
template <typename T>
390382
class VIdentityKernelImpl : public VIdentityKernel<T> {
391383
public:
392-
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
393-
void ComputeDeprecated(const T* x, T* y) const override {}
384+
JITKERNEL_DECLARE_STATIC_FUNC;
385+
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() {
386+
this->Compute = VIdentityRefer<T>;
387+
}
394388
};
395389

396-
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
390+
REGISTER_JITKERNEL(vmul, VMulKernel);
391+
REGISTER_JITKERNEL(vadd, VAddKernel);
392+
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
393+
REGISTER_JITKERNEL(vscal, VScalKernel);
394+
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
395+
REGISTER_JITKERNEL(vrelu, VReluKernel);
396+
REGISTER_JITKERNEL(videntity, VIdentityKernel);
397397

398398
} // namespace jitkernel
399399
} // namespace math

paddle/fluid/operators/math/jit_kernel_exp.cc

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ namespace jitkernel {
3636
namespace jit = platform::jit;
3737

3838
// TODO(TJ): move refer codes to one file
39+
// Refer code only focus on correctness
3940
template <typename T>
4041
void VExpRefer(const T* x, T* y, int n) {
4142
for (int i = 0; i < n; ++i) {
@@ -67,6 +68,7 @@ void VTanhRefer(const T* x, T* y, int n) {
6768
}
6869

6970
#ifdef PADDLE_WITH_MKLML
71+
// try to use MKL to speedup
7072
template <typename T>
7173
void VExpMKL(const T* x, T* y, int n);
7274

@@ -112,7 +114,6 @@ class VExpKernelImpl : public VExpKernel<T> {
112114
public:
113115
JITKERNEL_DECLARE_STATIC_FUNC;
114116
explicit VExpKernelImpl(int d) : VExpKernel<T>() {
115-
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
116117
#ifdef PADDLE_WITH_XBYAK
117118
if (useJIT(d)) {
118119
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
@@ -130,9 +131,7 @@ class VExpKernelImpl : public VExpKernel<T> {
130131
#endif
131132
this->Compute = VExpRefer<T>;
132133
}
133-
void ComputeDeprecated(const T* x, T* y) const override {
134-
VExpRefer(x, y, this->num_);
135-
}
134+
136135
#ifdef PADDLE_WITH_XBYAK
137136

138137
private:
@@ -166,7 +165,6 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
166165
public:
167166
JITKERNEL_DECLARE_STATIC_FUNC;
168167
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
169-
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
170168
#ifdef PADDLE_WITH_XBYAK
171169
if (useJIT(d)) {
172170
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
@@ -186,9 +184,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
186184
#endif
187185
this->Compute = VSigmoidRefer<T>;
188186
}
189-
void ComputeDeprecated(const T* x, T* y) const override {
190-
VSigmoidRefer(x, y, this->num_);
191-
}
187+
192188
#ifdef PADDLE_WITH_XBYAK
193189

194190
private:
@@ -221,7 +217,6 @@ class VTanhKernelImpl : public VTanhKernel<T> {
221217
public:
222218
JITKERNEL_DECLARE_STATIC_FUNC;
223219
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
224-
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
225220
#ifdef PADDLE_WITH_XBYAK
226221
if (useJIT(d)) {
227222
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
@@ -241,9 +236,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
241236
#endif
242237
this->Compute = VTanhRefer<T>;
243238
}
244-
void ComputeDeprecated(const T* x, T* y) const override {
245-
VTanhRefer(x, y, this->num_);
246-
}
239+
247240
#ifdef PADDLE_WITH_XBYAK
248241

249242
private:

paddle/fluid/operators/math/jit_kernel_rnn.cc

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
175175
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
176176
T* checked) const override {
177177
// gates: W_ch, W_ih, W_fh, W_oh
178-
act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_);
178+
act_gate_d3_->Compute(gates + d_, gates + d_, d3_);
179179

180180
/* C_t = C_t-1 * fgated + cand_gated * igated */
181-
act_cand_d_->ComputeDeprecated(gates, gates);
181+
act_cand_d_->Compute(gates, gates, d_);
182182
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
183183
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
184184
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
185185

186186
/* H_t = act_cell(C_t) * ogated */
187-
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
187+
act_cell_d_->Compute(ct, gates + d2_, d_);
188188
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
189189
}
190190
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
191191
/* C_t = igated * cgated*/
192-
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_);
193-
act_cand_d_->ComputeDeprecated(gates, gates);
192+
act_gate_d_->Compute(gates + d_, gates + d_, d_);
193+
act_cand_d_->Compute(gates, gates, d_);
194194
vmul_d_->Compute(gates, gates + d_, ct, d_);
195195
/* H_t = act_cell(C_t) * ogated */
196-
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
197-
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
196+
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
197+
act_cell_d_->Compute(ct, gates + d2_, d_);
198198
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
199199
}
200200

@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
292292
vmul_d_->Compute(wp_data, ct_1, checked, d_);
293293
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
294294
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
295-
act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_);
295+
act_gate_d2_->Compute(gates + d_, gates + d_, d2_);
296296
/* C_t = C_t-1 * fgated + cand_gated * igated*/
297-
act_cand_d_->ComputeDeprecated(gates, gates);
297+
act_cand_d_->Compute(gates, gates, d_);
298298
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
299299
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
300300
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
301301
/* get ogated*/
302302
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
303303
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
304-
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
304+
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
305305
/* H_t = act_cell(C_t) * ogated */
306-
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
306+
act_cell_d_->Compute(ct, gates + d2_, d_);
307307
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
308308
}
309309

310310
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
311311
/* C_t = igated * cgated*/
312-
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_);
313-
act_cand_d_->ComputeDeprecated(gates, gates);
312+
act_gate_d_->Compute(gates + d_, gates + d_, d_);
313+
act_cand_d_->Compute(gates, gates, d_);
314314
vmul_d_->Compute(gates, gates + d_, ct, d_);
315315
/* get outgated, put W_oc * C_t on igated */
316316
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
317317
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
318318
/* H_t = act_cell(C_t) * ogated */
319-
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
320-
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
319+
act_gate_d_->Compute(gates + d3_, gates + d3_, d_);
320+
act_cell_d_->Compute(ct, gates + d2_, d_);
321321
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
322322
}
323323

@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
376376
}
377377

378378
void ComputeH1(T* gates, T* ht) const override {
379-
act_gate_d_->ComputeDeprecated(gates, gates);
380-
act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_);
379+
act_gate_d_->Compute(gates, gates, d_);
380+
act_state_d_->Compute(gates + d2_, gates + d2_, d_);
381381
vmul_d_->Compute(gates, gates + d2_, ht, d_);
382382
}
383383

384384
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
385385
// W: {W_update, W_reset; W_state}
386-
act_gate_d2_->ComputeDeprecated(gates, gates);
386+
act_gate_d2_->Compute(gates, gates, d2_);
387387
vmul_d_->Compute(ht_1, gates + d_, ht, d_);
388388
}
389389

390390
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
391391
T* y = gates + d2_;
392-
act_state_d_->ComputeDeprecated(y, y);
392+
act_state_d_->Compute(y, y, d_);
393393
// out = zt*ht~ + (1-zt)*ht_1
394394
for (int i = 0; i < d_; ++i) {
395395
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ TEST(JitKernel, vexp) {
181181

182182
auto ttgts = GetCurrentUS();
183183
for (int i = 0; i < repeat; ++i) {
184-
// ker->ComputeDeprecated(x_data, ztgt_data);
184+
// ker->Compute(x_data, ztgt_data);
185185
ker->Compute(x_data, ztgt_data, d);
186186
}
187187
auto ttgte = GetCurrentUS();
@@ -345,8 +345,8 @@ void lstm_ctht_ref(
345345
const std::shared_ptr<
346346
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
347347
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
348-
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d);
349-
vtanh_d->ComputeDeprecated(gates, gates);
348+
vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
349+
vtanh_d->Compute(gates, gates, d);
350350
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
351351
const float min = SIGMOID_THRESHOLD_MIN;
352352
const float max = SIGMOID_THRESHOLD_MAX;
@@ -356,7 +356,7 @@ void lstm_ctht_ref(
356356
// H_t = act_cell(C_t) * ogated
357357
float tmp = ct[k] * 2;
358358
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
359-
vexp_1->ComputeDeprecated(&tmp, &tmp);
359+
vexp_1->Compute(&tmp, &tmp, 1);
360360
tmp = 2.f / (1.f + tmp) - 1.f;
361361
ht[k] = tmp * o[k];
362362
}
@@ -374,13 +374,13 @@ void lstm_ctht_better(
374374
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d,
375375
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
376376
int d2 = d * 2;
377-
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d);
378-
vtanh_d->ComputeDeprecated(gates, gates);
377+
vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
378+
vtanh_d->Compute(gates, gates, d);
379379
vmul_d->Compute(gates, gates + d, gates + d, d);
380380
vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
381381
vadd_d->Compute(gates + d, gates + d2, ct, d);
382382
/* H_t = act_cell(C_t) * ogated */
383-
vtanh_d->ComputeDeprecated(ct, gates + d2);
383+
vtanh_d->Compute(ct, gates + d2, d);
384384
vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
385385
}
386386

@@ -737,7 +737,7 @@ void vaddrelu_better(
737737
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
738738
const float* x, const float* y, float* z, int d) {
739739
vadd->Compute(x, y, z, d);
740-
vrelu->ComputeDeprecated(z, z);
740+
vrelu->Compute(z, z, d);
741741
}
742742

743743
TEST(JitKernel, vaddrelu) {

0 commit comments

Comments
 (0)