@@ -15,6 +15,7 @@ limitations under the License. */
15
15
#pragma once
16
16
17
17
#include < vector>
18
+ #include " paddle/fluid/framework/ddim.h"
18
19
#include " paddle/fluid/framework/eigen.h"
19
20
#include " paddle/fluid/framework/op_registry.h"
20
21
@@ -54,26 +55,45 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
54
55
auto scales = ctx.MultiInput <framework::Tensor>(" Scales" );
55
56
auto * out = ctx.Output <framework::Tensor>(" Out" );
56
57
57
- PADDLE_ENFORCE_EQ (scales[0 ]->numel (), in->dims ()[0 ],
58
- " The number of first scale values must be the same with "
59
- " first dimension value of Input(X)." );
60
-
61
58
auto quant_bits = ctx.Attr <std::vector<int >>(" quant_bits" );
62
59
int max_range = std::pow (2 , quant_bits[0 ] - 1 ) - 1 ;
63
60
64
61
auto & dev_ctx = ctx.template device_context <DeviceContext>();
65
62
out->mutable_data <T>(dev_ctx.GetPlace ());
66
63
67
64
auto dequant = DequantizeFunctor<DeviceContext, T>();
68
- for (int64_t i = 0 ; i < in->dims ()[0 ]; i++) {
69
- framework::Tensor one_channel_in = in->Slice (i, i + 1 );
70
- framework::Tensor one_channel_out = out->Slice (i, i + 1 );
71
- framework::Tensor one_channel_scale = scales[0 ]->Slice (i, i + 1 );
72
- dequant (dev_ctx, &one_channel_in, &one_channel_scale,
73
- static_cast <T>(max_range), &one_channel_out);
74
- }
75
-
76
- if (scales.size () == 2 ) {
65
+ if (scales.size () == 1 ) {
66
+ PADDLE_ENFORCE_EQ (
67
+ scales[0 ]->numel (), in->dims ()[0 ],
68
+ " The number of first scale values must be the same with "
69
+ " first dimension value of Input(X) when the `Scales` has only one "
70
+ " element." );
71
+ for (int64_t i = 0 ; i < in->dims ()[0 ]; i++) {
72
+ framework::Tensor one_channel_in = in->Slice (i, i + 1 );
73
+ framework::Tensor one_channel_out = out->Slice (i, i + 1 );
74
+ framework::Tensor one_channel_scale = scales[0 ]->Slice (i, i + 1 );
75
+ dequant (dev_ctx, &one_channel_in, &one_channel_scale,
76
+ static_cast <T>(max_range), &one_channel_out);
77
+ }
78
+ } else if (scales.size () == 2 ) {
79
+ PADDLE_ENFORCE_EQ (
80
+ scales[0 ]->numel (), in->dims ()[1 ],
81
+ " The number of first scale values must be the same with "
82
+ " second dimension value of Input(X) when the `Scales` has two "
83
+ " elements." );
84
+ for (int64_t i = 0 ; i < in->dims ()[0 ]; i++) {
85
+ framework::Tensor one_batch_in = in->Slice (i, i + 1 ).Resize (
86
+ framework::slice_ddim (in->dims (), 1 , in->dims ().size ()));
87
+ framework::Tensor one_batch_out = out->Slice (i, i + 1 ).Resize (
88
+ framework::slice_ddim (out->dims (), 1 , out->dims ().size ()));
89
+ for (int64_t j = 0 ; j < in->dims ()[1 ]; j++) {
90
+ framework::Tensor one_channel_in = one_batch_in.Slice (j, j + 1 );
91
+ framework::Tensor one_channel_out = one_batch_out.Slice (j, j + 1 );
92
+ framework::Tensor one_channel_scale = scales[0 ]->Slice (j, j + 1 );
93
+ dequant (dev_ctx, &one_channel_in, &one_channel_scale,
94
+ static_cast <T>(max_range), &one_channel_out);
95
+ }
96
+ }
77
97
PADDLE_ENFORCE_EQ (
78
98
scales[1 ]->numel (), 1 ,
79
99
" The second scale tensor should only have one value at now." );
0 commit comments