From 3ac3fc99edb2b435910c9780f671e01ad6f8ed1a Mon Sep 17 00:00:00 2001
From: PeiyuLau
Date: Fri, 30 Aug 2024 17:46:56 +0800
Subject: [PATCH] [MLU] add bilinear and bilinear_grad

---
 backends/mlu/kernels/bilinear_kernel.cc | 352 ++++++++++++++++++++++++
 1 file changed, 352 insertions(+)
 create mode 100644 backends/mlu/kernels/bilinear_kernel.cc

diff --git a/backends/mlu/kernels/bilinear_kernel.cc b/backends/mlu/kernels/bilinear_kernel.cc
new file mode 100644
index 00000000000..00b6ab86e12
--- /dev/null
+++ b/backends/mlu/kernels/bilinear_kernel.cc
@@ -0,0 +1,352 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "kernels/funcs/elementwise_utils.h"
+#include "kernels/funcs/mlu_baseop.h"
+#include "kernels/funcs/mlu_funcs.h"
+#include "kernels/funcs/reduce_op.h"
+#include "paddle/phi/kernels/funcs/slice_utils.h"
+
+namespace custom_kernel {
+
+template <typename T, typename Context>
+void SetTensorValueKernel(const Context& dev_ctx,
+                          const phi::DenseTensor& x,
+                          const phi::DenseTensor& value,
+                          const phi::IntArray& starts,
+                          const phi::IntArray& ends,
+                          const phi::IntArray& steps,
+                          const std::vector<int64_t>& axes,
+                          const std::vector<int64_t>& decrease_axes,
+                          const std::vector<int64_t>& none_axes,
+                          phi::DenseTensor* out);
+
+template <typename T, typename Context>
+void StridedSliceRawKernel(const Context& dev_ctx,
+                           const phi::DenseTensor& x,
+                           const std::vector<int>& axes,
+                           const phi::IntArray& starts,
+                           const phi::IntArray& ends,
+                           const phi::IntArray& strides,
+                           const std::vector<int>& infer_flags,
+                           const std::vector<int>& decrease_axis,
+                           phi::DenseTensor* out);
+
+template <typename T, typename Context>
+void BilinearKernel(const Context& dev_ctx,
+                    const phi::DenseTensor& x,
+                    const phi::DenseTensor& y,
+                    const phi::DenseTensor& weight,
+                    const paddle::optional<phi::DenseTensor>& bias,
+                    phi::DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+
+  auto batch_size = x.dims()[0];
+  auto weight_dims = weight.dims();
+  int out_dim = weight_dims[0];
+  auto x_dim = weight_dims[1];
+  auto y_dim = weight_dims[2];
+
+  // Create the intermediate variable to calculate the result of
+  // Input(X) multiplied by Input(Weight_i), the formula is:
+  // left_mul = X Weight_i.
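+  // Per output channel i this gives out[:, i] = reduce_sum(left_mul .* Y, axis=1),
+  // where ".*" is an elementwise product; Input(Bias), if provided, is
+  // broadcast-added to all channels at the end.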
+  Tensor left_mul;
+  left_mul.Resize(phi::make_ddim({batch_size, y_dim}));
+  dev_ctx.template Alloc<T>(&left_mul);
+
+  MLUCnnlTensorDesc x_desc(x, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc y_desc(y, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc weight_desc(weight, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc left_mul_desc(
+      left_mul, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+
+  phi::DenseTensor output_mat_slice;
+  output_mat_slice.Resize(phi::make_ddim({batch_size}));
+
+  phi::DenseTensor out_temp;
+  out_temp.Resize(out->dims());
+  dev_ctx.template Alloc<T>(&out_temp);
+  FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), &out_temp);
+
+  for (int64_t i = 0; i < out_dim; ++i) {
+    phi::DenseTensor weight_slice;
+    weight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
+    dev_ctx.template Alloc<T>(&weight_slice);
+    MLUCnnlTensorDesc weight_slice_desc(weight_slice);
+
+    phi::DenseTensor matmul_out;
+    matmul_out.Resize(phi::make_ddim({batch_size, y_dim}));
+    dev_ctx.template Alloc<T>(&matmul_out);
+    MLUCnnlTensorDesc matmul_out_desc(matmul_out);
+
+    int64_t next_i = i + 1;
+    int64_t value = 1;
+    const phi::IntArray& starts_indices = {i};
+    const phi::IntArray& ends_indices = {next_i};
+    const phi::IntArray& strides_indices = {value};
+    std::vector<int> infer_flags(1);
+    std::vector<int> decrease_axis;
+    std::vector<int> axes = {0};
+    custom_kernel::StridedSliceRawKernel<T, Context>(dev_ctx,
+                                                     weight,
+                                                     axes,
+                                                     starts_indices,
+                                                     ends_indices,
+                                                     strides_indices,
+                                                     infer_flags,
+                                                     decrease_axis,
+                                                     &weight_slice);
+
+    MLUCnnl::Matmul(dev_ctx,
+                    false,
+                    false,
+                    x_desc.get(),
+                    GetBasePtr(&x),
+                    weight_slice_desc.get(),
+                    GetBasePtr(&weight_slice),
+                    left_mul_desc.get(),
+                    GetBasePtr(&left_mul));
+
+    int axis = -1;
+    MLUOpTensorKernel<T>(
+        dev_ctx, left_mul, y, axis, CNNL_OP_TENSOR_MUL, &matmul_out);
+
+    phi::DenseTensor sum_out;
+    sum_out.Resize({batch_size});
+    const std::vector<int64_t>& dims = {1};
+    MLUReduceOp<T, Context>(dev_ctx,
+                            matmul_out,
+                            dims,
+                            false, /*keep_dim*/
+                            false, /*reduce_all*/
+                            "reduce_sum",
+                            &sum_out);
+
+    std::vector<int64_t> sum_axes = {1};
+    std::vector<int64_t> decrease_axes;
+    std::vector<int64_t> none_axes;
+    custom_kernel::SetTensorValueKernel<T, Context>(dev_ctx,
+                                                    out_temp,
+                                                    sum_out,
+                                                    starts_indices,
+                                                    ends_indices,
+                                                    strides_indices,
+                                                    sum_axes,
+                                                    decrease_axes,
+                                                    none_axes,
+                                                    &output_mat_slice);
+  }
+
+  if (bias.get_ptr()) {
+    phi::DenseTensor new_bias;
+    new_bias = bias.get();
+    int axis = -1;
+    MLUOpTensorKernel<T>(
+        dev_ctx, out_temp, new_bias, axis, CNNL_OP_TENSOR_ADD, out);
+  } else {
+    TensorCopy(dev_ctx, out_temp, false, out);
+  }
+}
+
+template <typename T, typename Context>
+void BilinearGradKernel(const Context& dev_ctx,
+                        const phi::DenseTensor& x,
+                        const phi::DenseTensor& y,
+                        const phi::DenseTensor& weight,
+                        const phi::DenseTensor& dout,
+                        phi::DenseTensor* dx,
+                        phi::DenseTensor* dy,
+                        phi::DenseTensor* dweight,
+                        phi::DenseTensor* dbias) {
+  auto batch_size = x.dims()[0];
+  auto weight_dims = weight.dims();
+  int out_dim = weight_dims[0];
+  auto x_dim = weight_dims[1];
+  auto y_dim = weight_dims[2];
+
+  // Create the intermediate variable to calculate the Output(Y@Grad).
+  phi::DenseTensor x_scale;
+  x_scale.Resize(phi::make_ddim({batch_size, x_dim}));
+  dev_ctx.template Alloc<T>(&x_scale);
+
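+  // Denoting by Weight_i the i-th {x_dim, y_dim} slice of Input(Weight) and
+  // with ".*" an elementwise product that broadcasts dout[:, i] over columns,
+  // the per-channel gradients accumulated below are:
+  //   X@Grad      += (dout[:, i] .* Y) * Weight_i^T
+  //   Y@Grad      += (dout[:, i] .* X) * Weight_i
+  //   Weight_i@Grad = (dout[:, i] .* X)^T * Y
+  //   Bias@Grad     = reduce_sum(dout, axis=0)
+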
+  // Create the intermediate variable to calculate the Output(X@Grad).
+  phi::DenseTensor y_scale;
+  y_scale.Resize(phi::make_ddim({batch_size, y_dim}));
+  dev_ctx.template Alloc<T>(&y_scale);
+
+  if (dx) {
+    dev_ctx.template Alloc<T>(dx);
+    FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), dx);
+  }
+  if (dy) {
+    dev_ctx.template Alloc<T>(dy);
+    FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), dy);
+  }
+  if (dweight) {
+    dev_ctx.template Alloc<T>(dweight);
+    FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), dweight);
+  }
+
+  if (dx || dy || dweight) {
+    phi::DenseTensor dx_temp;
+    dx_temp.Resize(dx->dims());
+    dev_ctx.template Alloc<T>(&dx_temp);
+    MLUCnnlTensorDesc dx_temp_desc(dx_temp);
+
+    phi::DenseTensor dy_temp;
+    dy_temp.Resize(dy->dims());
+    dev_ctx.template Alloc<T>(&dy_temp);
+    MLUCnnlTensorDesc dy_temp_desc(dy_temp);
+
+    phi::DenseTensor dweight_temp;
+    dweight_temp.Resize(phi::make_ddim({x_dim, y_dim}));
+    dev_ctx.template Alloc<T>(&dweight_temp);
+    MLUCnnlTensorDesc dweight_temp_desc(dweight_temp);
+
+    for (int64_t i = 0; i < out_dim; ++i) {
+      phi::DenseTensor weight_slice;
+      weight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
+      dev_ctx.template Alloc<T>(&weight_slice);
+      int64_t next_i = i + 1;
+      int64_t value = 1;
+      const phi::IntArray& starts_indices = {i};
+      const phi::IntArray& ends_indices = {next_i};
+      const phi::IntArray& strides_indices = {value};
+      std::vector<int> infer_flags(1);
+      std::vector<int> decrease_axis;
+      std::vector<int> axes = {0};
+      custom_kernel::StridedSliceRawKernel<T, Context>(dev_ctx,
+                                                       weight,
+                                                       axes,
+                                                       starts_indices,
+                                                       ends_indices,
+                                                       strides_indices,
+                                                       infer_flags,
+                                                       decrease_axis,
+                                                       &weight_slice);
+      weight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
+      MLUCnnlTensorDesc weight_slice_desc(weight_slice);
+      MLUCnnlTensorDesc x_scale_desc(x_scale);
+      MLUCnnlTensorDesc y_scale_desc(y_scale);
+      MLUCnnlTensorDesc dx_desc(*dx);
+      MLUCnnlTensorDesc dy_desc(*dy);
+      MLUCnnlTensorDesc y_desc(y);
+
+      // dout[:, i]
+      std::vector<int> dout_axes = {1};
+      std::vector<int64_t> decrease_axes;
+      phi::DenseTensor dout_mat_slice;
+      dout_mat_slice.Resize(phi::make_ddim({batch_size}));
+      custom_kernel::StridedSliceRawKernel<T, Context>(dev_ctx,
+                                                       dout,
+                                                       dout_axes,
+                                                       starts_indices,
+                                                       ends_indices,
+                                                       strides_indices,
+                                                       infer_flags,
+                                                       decrease_axis,
+                                                       &dout_mat_slice);
+      if (dx) {
+        int axis = -1;
+        dout_mat_slice.Resize({batch_size, 1});
+        MLUCnnlTensorDesc dout_mat_slice_desc(dout_mat_slice);
+        MLUOpTensorKernel<T>(
+            dev_ctx, dout_mat_slice, y, axis, CNNL_OP_TENSOR_MUL, &y_scale);
+        MLUCnnl::Matmul(dev_ctx,
+                        false,
+                        true,
+                        y_scale_desc.get(),
+                        GetBasePtr(&y_scale),
+                        weight_slice_desc.get(),
+                        GetBasePtr(&weight_slice),
+                        dx_temp_desc.get(),
+                        GetBasePtr(&dx_temp));
+        MLUOpTensorKernel<T>(
+            dev_ctx, dx_temp, *dx, axis, CNNL_OP_TENSOR_ADD, dx);
+      }
+      if (dy || dweight) {
+        int axis = -1;
+        dout_mat_slice.Resize({batch_size, 1});
+        MLUCnnlTensorDesc dout_mat_slice_desc(dout_mat_slice);
+        MLUOpTensorKernel<T>(
+            dev_ctx, dout_mat_slice, x, axis, CNNL_OP_TENSOR_MUL, &x_scale);
+        if (dy) {
+          MLUCnnl::Matmul(dev_ctx,
+                          false,
+                          false,
+                          x_scale_desc.get(),
+                          GetBasePtr(&x_scale),
+                          weight_slice_desc.get(),
+                          GetBasePtr(&weight_slice),
+                          dy_temp_desc.get(),
+                          GetBasePtr(&dy_temp));
+          MLUOpTensorKernel<T>(
+              dev_ctx, dy_temp, *dy, axis, CNNL_OP_TENSOR_ADD, dy);
+        }
+        if (dweight) {
+          MLUCnnl::Matmul(dev_ctx,
+                          true,
+                          false,
+                          x_scale_desc.get(),
+                          GetBasePtr(&x_scale),
+                          y_desc.get(),
+                          GetBasePtr(&y),
+                          dweight_temp_desc.get(),
+                          GetBasePtr(&dweight_temp));
+
+          std::vector<int64_t> dweight_axes = {0};
+          std::vector<int64_t> decrease_axes;
+          std::vector<int64_t> none_axes;
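+          // Write the per-channel weight gradient dweight_temp into the i-th
+          // slice of Output(Weight@Grad).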
+          phi::DenseTensor dweight_slice;
+          dweight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
+          dev_ctx.template Alloc<T>(&dweight_slice);
+          MLUCnnlTensorDesc dweight_slice_desc(dweight_slice);
+          custom_kernel::SetTensorValueKernel<T, Context>(dev_ctx,
+                                                          *dweight,
+                                                          dweight_temp,
+                                                          starts_indices,
+                                                          ends_indices,
+                                                          strides_indices,
+                                                          dweight_axes,
+                                                          decrease_axes,
+                                                          none_axes,
+                                                          &dweight_slice);
+        }
+      }
+    }
+  }
+
+  // Calculate the gradient of Input(Bias).
+  if (dbias) {
+    dev_ctx.template Alloc<T>(dbias);
+    const std::vector<int64_t>& dims = {0};
+    MLUReduceOp<T, Context>(dev_ctx,
+                            dout,
+                            dims,
+                            false, /*keep_dim*/
+                            false, /*reduce_all*/
+                            "reduce_sum",
+                            dbias);
+  }
+}
+
+}  // namespace custom_kernel
+
+PD_REGISTER_PLUGIN_KERNEL(
+    bilinear, mlu, ALL_LAYOUT, custom_kernel::BilinearKernel, float, double) {}
+
+PD_REGISTER_PLUGIN_KERNEL(bilinear_grad,
+                          mlu,
+                          ALL_LAYOUT,
+                          custom_kernel::BilinearGradKernel,
+                          float,
+                          double) {}
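A quick way to sanity-check the new kernels is to run the public bilinear API on
the plugin device and compare against the CPU result. The sketch below is not
part of the patch; it assumes the PaddleCustomDevice MLU plugin is installed and
that the custom device is registered under the name "mlu" (adjust the device
string to match your build), and it uses only public Paddle APIs
(paddle.nn.functional.bilinear, paddle.set_device):

    # Hypothetical smoke test; assumes the PaddleCustomDevice MLU plugin is
    # installed and exposes the custom device as "mlu".
    import numpy as np
    import paddle

    x = np.random.rand(4, 5).astype("float32")
    y = np.random.rand(4, 6).astype("float32")
    w = np.random.rand(3, 5, 6).astype("float32")  # [out_dim, x_dim, y_dim]
    b = np.random.rand(1, 3).astype("float32")     # [1, out_dim]

    def run(device):
        paddle.set_device(device)
        xt = paddle.to_tensor(x, stop_gradient=False)
        yt = paddle.to_tensor(y, stop_gradient=False)
        wt = paddle.to_tensor(w, stop_gradient=False)
        bt = paddle.to_tensor(b, stop_gradient=False)
        out = paddle.nn.functional.bilinear(xt, yt, wt, bt)  # exercises bilinear
        out.sum().backward()                                  # exercises bilinear_grad
        return out.numpy(), xt.grad.numpy(), yt.grad.numpy(), wt.grad.numpy()

    cpu_results = run("cpu")
    mlu_results = run("mlu")  # device name is an assumption for this sketch
    for ref, got in zip(cpu_results, mlu_results):
        np.testing.assert_allclose(ref, got, rtol=1e-5, atol=1e-5)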