Skip to content

Commit 627d556

Browse files
authored
cherry pick remove pow to speed up in dequantize_log op (#24607) (#24723)
* remove pow in speed up in dequantize_log test=develop * remove pow in speed up in dequantize_log test=develop * fix unittest test=develop
1 parent b84fedf commit 627d556

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

paddle/fluid/operators/dequantize_log_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
3131
int ind = in->numel();
3232
for (size_t i = 0; i < (unsigned)ind; i++) {
3333
if (input_data[i] < 0) {
34-
output_data[i] = -std::pow(2.0, dict_data[input_data[i] + 128]);
34+
output_data[i] = -dict_data[input_data[i] + 128];
3535
} else {
36-
output_data[i] = std::pow(2.0, dict_data[input_data[i]]);
36+
output_data[i] = dict_data[input_data[i]];
3737
}
3838
}
3939
}

paddle/fluid/operators/dequantize_log_op.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ __global__ void KeDequantize(const T* in, const float* dict, int num,
2626
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
2727
if (idx < num) {
2828
if (in[idx] < 0) {
29-
out[idx] = -std::pow(static_cast<float>(2.0), dict[in[idx] + 128]);
29+
out[idx] = -dict[in[idx] + 128];
3030
} else {
31-
out[idx] = std::pow(static_cast<float>(2.0), dict[in[idx]]);
31+
out[idx] = dict[in[idx]];
3232
}
3333
}
3434
}

python/paddle/fluid/tests/unittests/test_dequantize_log_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def dequantize_log(x, dict_data):
2626
output_data_f = output_data.flatten()
2727
for i in range(x_f.size):
2828
if x_f[i] < 0:
29-
output_data_f[i] = -np.power(2, dict_data[x_f[i] + 128])
29+
output_data_f[i] = -dict_data[x_f[i] + 128]
3030
else:
31-
output_data_f[i] = np.power(2, dict_data[x_f[i]])
31+
output_data_f[i] = dict_data[x_f[i]]
3232
return output_data_f.reshape(x.shape)
3333

3434

0 commit comments

Comments
 (0)