@@ -27,10 +27,6 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
     Tensor* output = context.Output<Tensor>("Output");
-    // Tensor* max_input = context.Output<Tensor>("MaxInput");
-    // Tensor* max_filter = context.Output<Tensor>("MaxFilter");
-    // max_input->mutable_data<T>(context.GetPlace());
-    // max_filter->mutable_data<T>(context.GetPlace());
     output->mutable_data<T>(context.GetPlace());
     int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
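The dims() indexing in this kernel (and in the grad kernel below) presumes fixed tensor layouts; our reading of the mapping, an inference from the code rather than anything stated in the diff:

// Assumed layouts (inferred, not part of the diff):
// input->dims()  -> {batch_size, img_c, img_h, img_w}   (NCHW)
// filter.dims()  -> {f, img_c / groups, win_h, win_w}   (OIHW)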
@@ -43,62 +39,25 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     const int f = static_cast<int>(filter.dims()[0]);
     const int win_h = static_cast<int>(filter.dims()[2]);
     const int win_w = static_cast<int>(filter.dims()[3]);
-    PADDLE_ENFORCE_EQ(
-        dilations[0] == 1 && dilations[1] == 1, true,
-        platform::errors::InvalidArgument("XPU only support dilation == 1."));
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    // PADDLE_ENFORCE_EQ(
-    //     xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
-    //                  max_input->data<T>()) == xpu::Error_t::SUCCESS,
-    //     true, platform::errors::InvalidArgument(
-    //               "XPU conv kernel error, can not find max_input, please "
-    //               "check whether Baidu Kunlun "
-    //               "Card is properly installed."));
-    // PADDLE_ENFORCE_EQ(
-    //     xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
-    //                  max_filter->data<T>()) == xpu::Error_t::SUCCESS,
-    //     true, platform::errors::InvalidArgument(
-    //               "XPU conv kernel error, can not find max_filter, please "
-    //               "check whether Baidu Kunlun "
-    //               "Card is properly installed."));
-    if (groups == 1) {
-      int r = xpu::conv2d_forward_int16<float, float, float, float>(
-          dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
-          strides[0], strides[1], paddings[0], paddings[1], dilations[0],
-          dilations[1], groups, input->data<float>(), filter.data<float>(),
-          output->data<float>(), nullptr, nullptr, xpu::Activation_t::LINEAR,
-          nullptr, nullptr);
-      // max_input->data<float>(), max_filter->data<float>());
-      PADDLE_ENFORCE_EQ(
-          r, XPU_SUCCESS,
-          platform::errors::External("XPU conv kernel return wrong value[%d], "
-                                     "please check whether Baidu Kunlun Card "
-                                     "is properly installed.",
-                                     r));
-    } else {
-      int r = xpu::conv2d_int16_with_group<float, float, float>(
-          dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
-          output->data<float>(), batch_size, img_c, img_h, img_w, f, win_h,
-          win_w, groups, strides[0], strides[1], paddings[0], paddings[1],
-          nullptr, nullptr);
-      // max_input->data<float>(), max_filter->data<float>());
-      PADDLE_ENFORCE_EQ(
-          r, XPU_SUCCESS,
-          platform::errors::External("XPU conv kernel return wrong value[%d], "
-                                     "please check whether Baidu Kunlun Card "
-                                     "is properly installed.",
-                                     r));
-    }
+    std::vector<int> k_size;
+    k_size.push_back(win_h);
+    k_size.push_back(win_w);
+    int r = xpu::conv2d<float, float, float, int16_t>(
+        dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
+        output->data<float>(), batch_size, img_c, img_h, img_w, f, k_size,
+        strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU conv kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
 template <typename DeviceContext, typename T>
 class GemmConvGradXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
-    // const Tensor* max_input = context.Input<Tensor>("MaxInput");
-    // const Tensor* max_filter = context.Input<Tensor>("MaxFilter");
-    // Tensor* max_output_grad = context.Output<Tensor>("MaxOutputGrad");
     const Tensor* output_grad =
         context.Input<Tensor>(framework::GradVarName("Output"));
     Tensor* input_grad =
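For shape checking, the output spatial size implied by the xpu::conv2d call above follows the standard convolution shape formula; a minimal sketch (the helper name is ours, not part of the XPU API):

// out = (in + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1
inline int ConvOutSize(int in, int k, int pad, int dilation, int stride) {
  return (in + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1;
}
// e.g. out_h = ConvOutSize(img_h, win_h, paddings[0], dilations[0], strides[0]);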
@@ -115,11 +74,6 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
     const int batch_size = static_cast<int>(input->dims()[0]);
-    PADDLE_ENFORCE_EQ(groups == 1, true, platform::errors::InvalidArgument(
-                                             "XPU only support groups == 1."));
-    PADDLE_ENFORCE_EQ(
-        dilations[0] == 1 && dilations[1] == 1, true,
-        platform::errors::InvalidArgument("XPU only support dilation == 1."));
     const int img_c = static_cast<int>(input->dims()[1]);
     const int img_h = static_cast<int>(input->dims()[2]);
     const int img_w = static_cast<int>(input->dims()[3]);
@@ -133,52 +87,24 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
       filter_grad->mutable_data<T>(context.GetPlace());
     }
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    // max_output_grad->Resize({4});
-    // max_output_grad->mutable_data<T>(context.GetPlace());
-    // PADDLE_ENFORCE_EQ(
-    //     xpu::findmax(dev_ctx.x_context(), output_grad->data<T>(),
-    //                  output_grad->numel(),
-    //                  max_output_grad->data<T>()) == xpu::Error_t::SUCCESS,
-    //     true,
-    //     platform::errors::External(
-    //         "XPU conv kernel error, can not find max_output_grad, please
-    //         check "
-    //         "whether Baidu Kunlun Card is "
-    //         "properly installed."));
-    if (input_grad) {
-      int r = xpu::conv2d_backward_int16(
-          dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
-          strides[0], strides[1], paddings[0], paddings[1], dilations[0],
-          dilations[1], groups, output_grad->data<float>(),
-          filter.data<float>(), input_grad->data<float>(), nullptr, nullptr);
-      // max_output_grad->data<float>(), max_filter->data<float>());
-      PADDLE_ENFORCE_EQ(
-          r, XPU_SUCCESS,
-          platform::errors::External("XPU conv kernel return wrong value[%d], "
-                                     "please check whether Baidu Kunlun Card "
-                                     "is properly installed.",
-                                     r));
-    }
-    if (filter_grad) {
-      int r = xpu::conv2d_backward_weight_int16(
-          dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
-          strides[0], strides[1], paddings[0], paddings[1], dilations[0],
-          dilations[1], groups, output_grad->data<float>(),
-          input->data<float>(), filter_grad->data<float>(), nullptr, nullptr);
-      // max_output_grad->data<float>(), max_input->data<float>());
-      PADDLE_ENFORCE_EQ(
-          r, XPU_SUCCESS,
-          platform::errors::External("XPU conv kernel return wrong value[%d], "
-                                     "please check whether Baidu Kunlun Card "
-                                     "is properly installed.",
-                                     r));
-    }
+    std::vector<int> k_size;
+    k_size.push_back(win_h);
+    k_size.push_back(win_w);
+    int r = xpu::conv2d_grad<float, float, float, int16_t>(
+        dev_ctx.x_context(), input->data<T>(), filter.data<T>(),
+        output_grad->data<T>(), input_grad ? input_grad->data<T>() : nullptr,
+        filter_grad ? filter_grad->data<T>() : nullptr, batch_size, img_c,
+        img_h, img_w, f, k_size, strides, paddings, dilations, groups, nullptr,
+        nullptr, nullptr, nullptr, nullptr, true);
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU conv kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
-// TODO(xingzhaolong): neon kernel for mobile
 REGISTER_OP_XPU_KERNEL(
     depthwise_conv2d,
     ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
@@ -187,4 +113,7 @@ REGISTER_OP_XPU_KERNEL(
 REGISTER_OP_XPU_KERNEL(
     conv2d_grad,
     ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+REGISTER_OP_XPU_KERNEL(
+    depthwise_conv2d_grad,
+    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
 #endif
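Because groups is passed straight through to xpu::conv2d and xpu::conv2d_grad, the same kernel classes can back conv2d, depthwise_conv2d, and both of their grad ops. A depthwise convolution is simply a grouped convolution with one group per input channel; a minimal sketch of that invariant (the helper name is ours, assuming the usual depthwise convention):

// Sketch: depthwise conv == grouped conv with groups == input channels,
// so the filter's per-group input-channel dim (in_c / groups) is 1.
inline bool IsDepthwise(int img_c, int groups, int filter_in_c) {
  return groups == img_c && filter_in_c == 1;
}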