Commit 8257d25

[MLU] add bilinear and bilinear_grad

1 parent 6057ef4 commit 8257d25

1 file changed: 386 additions & 0 deletions
@@ -0,0 +1,386 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kernels/funcs/elementwise_utils.h"
#include "kernels/funcs/mlu_baseop.h"
#include "kernels/funcs/mlu_funcs.h"
#include "kernels/funcs/reduce_op.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace custom_kernel {

template <typename T, typename Context>
void SetTensorValueKernel(const Context& dev_ctx,
                          const phi::DenseTensor& x,
                          const phi::DenseTensor& value,
                          const phi::IntArray& starts,
                          const phi::IntArray& ends,
                          const phi::IntArray& steps,
                          const std::vector<int64_t>& axes,
                          const std::vector<int64_t>& decrease_axes,
                          const std::vector<int64_t>& none_axes,
                          phi::DenseTensor* out);

/*template <typename T, typename Context>
void StridedSliceOutDims(const std::vector<int64_t>& starts,
                         const std::vector<int64_t>& ends,
                         const std::vector<int64_t>& strides,
                         const std::vector<int>& axes,
                         const std::vector<int>& infer_flags,
                         const phi::DDim in_dims,
                         const std::vector<int>& decrease_axis,
                         int64_t* out_dims_vector,
                         const size_t size,
                         bool infer_shape);

template <typename T, typename Context>
void StridedSliceFunctor(int64_t* starts,
                         int64_t* ends,
                         int64_t* strides,
                         const int* axes,
                         int* reverse_axis,
                         const phi::DDim dims,
                         const std::vector<int>& infer_flags,
                         const std::vector<int>& decrease_axis,
                         const size_t size);

template <typename T, typename Context, size_t D>
void StridedSliceCompute(const Context& dev_ctx,
                         const phi::DenseTensor& x,
                         const std::vector<int>& axes,
                         const phi::IntArray& starts_array,
                         const phi::IntArray& ends_array,
                         const phi::IntArray& strides_array,
                         const std::vector<int>& infer_flags,
                         const std::vector<int>& decrease_axis,
                         phi::DenseTensor* out);*/

template <typename T, typename Context>
void StridedSliceRawKernel(const Context& dev_ctx,
                           const phi::DenseTensor& x,
                           const std::vector<int>& axes,
                           const phi::IntArray& starts,
                           const phi::IntArray& ends,
                           const phi::IntArray& strides,
                           const std::vector<int>& infer_flags,
                           const std::vector<int>& decrease_axis,
                           phi::DenseTensor* out);

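// Forward of the bilinear op: for each output channel i,
//   out[b, i] = x[b, :] * W[i, :, :] * y[b, :]^T + bias[i],
// where weight has shape [out_dim, x_dim, y_dim] and W[i] is its i-th
// [x_dim, y_dim] slice.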
template <typename T, typename Context>
void BilinearKernel(const Context& dev_ctx,
                    const phi::DenseTensor& x,
                    const phi::DenseTensor& y,
                    const phi::DenseTensor& weight,
                    const paddle::optional<phi::DenseTensor>& bias,
                    phi::DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);

  auto batch_size = x.dims()[0];
  auto weight_dims = weight.dims();
  int out_dim = weight_dims[0];
  auto x_dim = weight_dims[1];
  auto y_dim = weight_dims[2];

  // Create the intermediate variable to calculate the result of
  // Input(X) multiplied by Input(Weight_i), the formula is:
  // left_mul = X Weight_i.
  Tensor left_mul;
  left_mul.Resize(phi::make_ddim({batch_size, y_dim}));
  dev_ctx.template Alloc<T>(&left_mul);

  MLUCnnlTensorDesc x_desc(x, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
  MLUCnnlTensorDesc y_desc(y, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
  MLUCnnlTensorDesc weight_desc(weight, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
  MLUCnnlTensorDesc left_mul_desc(
      left_mul, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());

  phi::DenseTensor output_mat_slice;
  output_mat_slice.Resize(phi::make_ddim({batch_size}));

  phi::DenseTensor out_temp;
  out_temp.Resize(out->dims());
  dev_ctx.template Alloc<T>(&out_temp);
  FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), &out_temp);

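  // For each output channel i: slice W[i] out of weight, compute
  // left_mul = x * W[i] ([batch_size, y_dim]), multiply it elementwise
  // by y, and reduce-sum over axis 1 to get out[:, i].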
  for (int64_t i = 0; i < out_dim; ++i) {
    phi::DenseTensor weight_slice;
    weight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
    dev_ctx.template Alloc<T>(&weight_slice);
    MLUCnnlTensorDesc weight_slice_desc(weight_slice);

    phi::DenseTensor matmul_out;
    matmul_out.Resize(phi::make_ddim({batch_size, y_dim}));
    dev_ctx.template Alloc<T>(&matmul_out);
    MLUCnnlTensorDesc matmul_out_desc(matmul_out);
    int64_t next_i = i + 1;
    int64_t value = 1;
    const phi::IntArray& starts_indices = {i};
    const phi::IntArray& ends_indices = {next_i};
    const phi::IntArray& strides_indices = {value};
    std::vector<int> infer_flags(1);
    std::vector<int> decrease_axis;
    std::vector<int> axes = {0};
    custom_kernel::StridedSliceRawKernel<T, Context>(dev_ctx,
                                                     weight,
                                                     axes,
                                                     starts_indices,
                                                     ends_indices,
                                                     strides_indices,
                                                     infer_flags,
                                                     decrease_axis,
                                                     &weight_slice);

    MLUCnnl::Matmul(dev_ctx,
                    false,
                    false,
                    x_desc.get(),
                    GetBasePtr(&x),
                    weight_slice_desc.get(),
                    GetBasePtr(&weight_slice),
                    left_mul_desc.get(),
                    GetBasePtr(&left_mul));

    int axis = -1;
    MLUOpTensorKernel<T>(
        dev_ctx, left_mul, y, axis, CNNL_OP_TENSOR_MUL, &matmul_out);

    phi::DenseTensor sum_out;
    sum_out.Resize({batch_size});
    const std::vector<int64_t>& dims = {1};
    MLUReduceOp<T>(dev_ctx,
                   matmul_out,
                   dims,
                   false, /*keep_dim*/
                   false, /*reduce_all*/
                   "reduce_sum",
                   &sum_out);

    std::vector<int64_t> sum_axes = {1};
    std::vector<int64_t> decrease_axes;
    std::vector<int64_t> none_axes;
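    // Write sum_out into column i of the output (set_value over the
    // range [i, i + 1) along axis 1).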
    custom_kernel::SetTensorValueKernel<T, Context>(dev_ctx,
                                                    out_temp,
                                                    sum_out,
                                                    starts_indices,
                                                    ends_indices,
                                                    strides_indices,
                                                    sum_axes,
                                                    decrease_axes,
                                                    none_axes,
                                                    &output_mat_slice);
  }

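  // Add the bias (broadcast over the batch) when it is provided;
  // otherwise copy the accumulated result to out.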
  if (bias.get_ptr()) {
    phi::DenseTensor new_bias;
    new_bias = bias.get();
    int axis = -1;
    MLUOpTensorKernel<T>(
        dev_ctx, out_temp, new_bias, axis, CNNL_OP_TENSOR_ADD, out);
  } else {
    TensorCopy(dev_ctx, out_temp, false, out);
  }
}

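// Backward of the bilinear op. With x_scale = dout[:, i] * x and
// y_scale = dout[:, i] * y (broadcast over the feature dimension):
//   dx     += y_scale * W[i]^T   (accumulated over i),
//   dy     += x_scale * W[i]     (accumulated over i),
//   dW[i]   = x_scale^T * y,
//   dbias   = reduce_sum(dout, axis 0).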
template <typename T, typename Context>
void BilinearGradKernel(const Context& dev_ctx,
                        const phi::DenseTensor& x,
                        const phi::DenseTensor& y,
                        const phi::DenseTensor& weight,
                        const phi::DenseTensor& dout,
                        phi::DenseTensor* dx,
                        phi::DenseTensor* dy,
                        phi::DenseTensor* dweight,
                        phi::DenseTensor* dbias) {
  auto batch_size = x.dims()[0];
  auto weight_dims = weight.dims();
  int out_dim = weight_dims[0];
  auto x_dim = weight_dims[1];
  auto y_dim = weight_dims[2];

  // Create the intermediate variable to calculate the Output(Y@Grad).
  phi::DenseTensor x_scale;
  x_scale.Resize(phi::make_ddim({batch_size, x_dim}));
  dev_ctx.template Alloc<T>(&x_scale);

  // Create the intermediate variable to calculate the Output(X@Grad).
  phi::DenseTensor y_scale;
  y_scale.Resize(phi::make_ddim({batch_size, y_dim}));
  dev_ctx.template Alloc<T>(&y_scale);

  if (dx) {
    dev_ctx.template Alloc<T>(dx);
    FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), dx);
  }
  if (dy) {
    dev_ctx.template Alloc<T>(dy);
    FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), dy);
  }
  if (dweight) {
    dev_ctx.template Alloc<T>(dweight);
    FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), dweight);
  }

  if (dx || dy || dweight) {
    phi::DenseTensor dx_temp;
    dx_temp.Resize(dx->dims());
    dev_ctx.template Alloc<T>(&dx_temp);
    MLUCnnlTensorDesc dx_temp_desc(dx_temp);

    phi::DenseTensor dy_temp;
    dy_temp.Resize(dy->dims());
    dev_ctx.template Alloc<T>(&dy_temp);
    MLUCnnlTensorDesc dy_temp_desc(dy_temp);

    phi::DenseTensor dweight_temp;
    dweight_temp.Resize(phi::make_ddim({x_dim, y_dim}));
    dev_ctx.template Alloc<T>(&dweight_temp);
    MLUCnnlTensorDesc dweight_temp_desc(dweight_temp);

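    // Per output channel: slice W[i] and dout[:, i], then form the
    // gradient contributions described above.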
    for (int64_t i = 0; i < out_dim; ++i) {
      phi::DenseTensor weight_slice;
      weight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
      dev_ctx.template Alloc<T>(&weight_slice);
      int64_t next_i = i + 1;
      int64_t value = 1;
      const phi::IntArray& starts_indices = {i};
      const phi::IntArray& ends_indices = {next_i};
      const phi::IntArray& strides_indices = {value};
      std::vector<int> infer_flags(1);
      std::vector<int> decrease_axis;
      std::vector<int> axes = {0};
      custom_kernel::StridedSliceRawKernel<T, Context>(dev_ctx,
                                                       weight,
                                                       axes,
                                                       starts_indices,
                                                       ends_indices,
                                                       strides_indices,
                                                       infer_flags,
                                                       decrease_axis,
                                                       &weight_slice);
      weight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
      MLUCnnlTensorDesc weight_slice_desc(weight_slice);
      MLUCnnlTensorDesc x_scale_desc(x_scale);
      MLUCnnlTensorDesc y_scale_desc(y_scale);
      MLUCnnlTensorDesc dx_desc(*dx);
      MLUCnnlTensorDesc dy_desc(*dy);
      MLUCnnlTensorDesc y_desc(y);

      // dout[:, i]
      std::vector<int> dout_axes = {1};
      std::vector<int> decrease_axes;
      phi::DenseTensor dout_mat_slice;
      dout_mat_slice.Resize(phi::make_ddim({batch_size}));
      custom_kernel::StridedSliceRawKernel<T, Context>(dev_ctx,
                                                       dout,
                                                       dout_axes,
                                                       starts_indices,
                                                       ends_indices,
                                                       strides_indices,
                                                       infer_flags,
                                                       decrease_axis,
                                                       &dout_mat_slice);
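      // dx += (dout[:, i] * y) * W[i]^T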
      if (dx) {
        int axis = -1;
        dout_mat_slice.Resize({batch_size, 1});
        MLUCnnlTensorDesc dout_mat_slice_desc(dout_mat_slice);
        MLUOpTensorKernel<T>(
            dev_ctx, dout_mat_slice, y, axis, CNNL_OP_TENSOR_MUL, &y_scale);
        MLUCnnl::Matmul(dev_ctx,
                        false,
                        true,
                        y_scale_desc.get(),
                        GetBasePtr(&y_scale),
                        weight_slice_desc.get(),
                        GetBasePtr(&weight_slice),
                        dx_temp_desc.get(),
                        GetBasePtr(&dx_temp));
        MLUOpTensorKernel<T>(
            dev_ctx, dx_temp, *dx, axis, CNNL_OP_TENSOR_ADD, dx);
      }
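      // x_scale = dout[:, i] * x, shared by the dy and dweight paths:
      //   dy    += x_scale * W[i]
      //   dW[i]  = x_scale^T * y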
      if (dy || dweight) {
        int axis = -1;
        dout_mat_slice.Resize({batch_size, 1});
        MLUCnnlTensorDesc dout_mat_slice_desc(dout_mat_slice);
        MLUOpTensorKernel<T>(
            dev_ctx, dout_mat_slice, x, axis, CNNL_OP_TENSOR_MUL, &x_scale);
        if (dy) {
          MLUCnnl::Matmul(dev_ctx,
                          false,
                          false,
                          x_scale_desc.get(),
                          GetBasePtr(&x_scale),
                          weight_slice_desc.get(),
                          GetBasePtr(&weight_slice),
                          dy_temp_desc.get(),
                          GetBasePtr(&dy_temp));
          MLUOpTensorKernel<T>(
              dev_ctx, dy_temp, *dy, axis, CNNL_OP_TENSOR_ADD, dy);
        }
        if (dweight) {
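          // dweight_temp = x_scale^T * y, then written back into
          // dweight[i] with set_value along axis 0.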
          MLUCnnl::Matmul(dev_ctx,
                          true,
                          false,
                          x_scale_desc.get(),
                          GetBasePtr(&x_scale),
                          y_desc.get(),
                          GetBasePtr(&y),
                          dweight_temp_desc.get(),
                          GetBasePtr(&dweight_temp));

          std::vector<int64_t> dweight_axes = {0};
          std::vector<int64_t> decrease_axes;
          std::vector<int64_t> none_axes;
          phi::DenseTensor dweight_slice;
          dweight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
          dev_ctx.template Alloc<T>(&dweight_slice);
          MLUCnnlTensorDesc dweight_slice_desc(dweight_slice);
          custom_kernel::SetTensorValueKernel<T, Context>(dev_ctx,
                                                          *dweight,
                                                          dweight_temp,
                                                          starts_indices,
                                                          ends_indices,
                                                          strides_indices,
                                                          dweight_axes,
                                                          decrease_axes,
                                                          none_axes,
                                                          &dweight_slice);
        }
      }
    }
    // calculate the gradient of Input(Bias).
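    // dbias[j] = sum over the batch of dout[:, j].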
    if (dbias) {
      dev_ctx.template Alloc<T>(dbias);
      const std::vector<int64_t>& dims = {0};
      MLUReduceOp<T>(dev_ctx,
                     dout,
                     dims,
                     false, /*keep_dim*/
                     false, /*reduce_all*/
                     "reduce_sum",
                     dbias);
    }
  }
}

}  // namespace custom_kernel

PD_REGISTER_PLUGIN_KERNEL(
    bilinear, mlu, ALL_LAYOUT, custom_kernel::BilinearKernel, float, double) {}

PD_REGISTER_PLUGIN_KERNEL(bilinear_grad,
                          mlu,
                          ALL_LAYOUT,
                          custom_kernel::BilinearGradKernel,
                          float,
                          double) {}
