/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/group_norm_op.h"
#include <vector>
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

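// GroupNormFunction bundles the elementwise and reduction NPU ops
// (ReduceMeanD, ReduceSumD, AddV2, Sub, Mul, Div, DivNoNan, TransposeD,
// Sqrt, Adds) that the forward and backward kernels below are composed of.
// Every helper expects its output tensor to be allocated by the caller.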
template <typename T>
struct GroupNormFunction {
 public:
  explicit GroupNormFunction(const framework::ExecutionContext& ctx)
      : ctx(ctx) {
    place = ctx.GetPlace();
    stream = ctx.template device_context<paddle::platform::NPUDeviceContext>()
                 .stream();
  }
  void ReduceMean(const Tensor* x, Tensor* y, const std::vector<int>& dim,
                  bool keep_dims = true) {
    // y must be allocated before the call
    const auto& runner = NpuOpRunner("ReduceMeanD", {*x}, {*y},
                                     {{"axes", dim}, {"keep_dims", keep_dims}});
    runner.Run(stream);
  }
  void ReduceSum(const Tensor* x, Tensor* y, const std::vector<int>& dim,
                 bool keep_dims = true) {
    // y must be allocated before the call
    const auto& runner = NpuOpRunner("ReduceSumD", {*x}, {*y},
                                     {{"axes", dim}, {"keep_dims", keep_dims}});
    runner.Run(stream);
  }
  void Add(const Tensor* x, const Tensor* y, Tensor* z) {
    // z must be allocated before the call
    const auto& runner = NpuOpRunner("AddV2", {*x, *y}, {*z}, {});
    runner.Run(stream);
  }
  void Sub(const Tensor* x, const Tensor* y, Tensor* z) {
    // z must be allocated before the call
    const auto& runner = NpuOpRunner("Sub", {*x, *y}, {*z}, {});
    runner.Run(stream);
  }
  void Mul(const Tensor* x, const Tensor* y, Tensor* z) {
    // z must be allocated before the call
    const auto& runner = NpuOpRunner("Mul", {*x, *y}, {*z}, {});
    runner.Run(stream);
  }
  void Div(const Tensor* x, const Tensor* y, Tensor* z) {
    // z must be allocated before the call
    const auto& runner = NpuOpRunner("Div", {*x, *y}, {*z}, {});
    runner.Run(stream);
  }
  void DivNoNan(const Tensor* x, const Tensor* y, Tensor* z) {
    // z must be allocated before the call
    const auto& runner = NpuOpRunner("DivNoNan", {*x, *y}, {*z}, {});
    runner.Run(stream);
  }
  void Transpose(const Tensor* x, Tensor* y, const std::vector<int>& axis) {
    // y must be allocated before the call
    const auto& runner =
        NpuOpRunner("TransposeD", {*x}, {*y}, {{"perm", axis}});
    runner.Run(stream);
  }
  void Sqrt(const Tensor* x, Tensor* y) {
    // y must be allocated before the call
    const auto& runner = NpuOpRunner("Sqrt", {*x}, {*y}, {});
    runner.Run(stream);
  }
  void Adds(const Tensor* x, float scalar, Tensor* y) {
    // y must be allocated before the call
    const auto& runner = NpuOpRunner("Adds", {*x}, {*y}, {{"value", scalar}});
    runner.Run(stream);
  }
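  // ReduceMeanToNG reduces a per-group view of a tensor down to one mean per
  // (sample, group) pair: [N, G, C*H*W/G] -> [N, G, 1] for NCHW, and
  // [N, C*H*W/G, G] -> [N, 1, G] for NHWC (after a transpose so the
  // reduction always runs over the last axis).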
  Tensor ReduceMeanToNG(const Tensor* x, const DataLayout& data_layout,
                        const int64_t N, const int64_t C, const int64_t H,
                        const int64_t W, const int G) {
    Tensor y(x->type());
    if (data_layout == DataLayout::kNCHW) {
      y.mutable_data<T>({N, G, 1}, place);
      // shape of x is [N, G, C*H*W/G]
      this->ReduceMean(x, &y, std::vector<int>{2});
    } else {
      y.mutable_data<T>({N, 1, G}, place);
      // shape of x is [N, C*H*W/G, G]
      Tensor x_trans(x->type());
      x_trans.mutable_data<T>({N, G, C * H * W / G}, place);
      this->Transpose(x, &x_trans, std::vector<int>{0, 2, 1});
      this->ReduceMean(&x_trans, &y, std::vector<int>{2});
    }
    return y;
  }

 private:
  platform::Place place;
  aclrtStream stream;
  const framework::ExecutionContext& ctx;
};

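// Forward pass: the input is brought into NCHW order, flattened to
// [N * groups, C * H * W / groups] so that every row holds one group,
// normalized with the per-row mean and variance, and finally scaled and
// shifted per channel.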
template <typename T>
class GroupNormNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const float epsilon = ctx.Attr<float>("epsilon");
    auto* scale = ctx.Input<Tensor>("Scale");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* x = ctx.Input<Tensor>("X");

    auto* y = ctx.Output<Tensor>("Y");
    auto* mean = ctx.Output<Tensor>("Mean");
    auto* var = ctx.Output<Tensor>("Variance");
    const auto groups = ctx.Attr<int>("groups");

    auto place = ctx.GetPlace();
    Tensor xnorm(x->type());
    xnorm.mutable_data<T>(x->dims(), place);
    GroupNormFunction<T> F(ctx);
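    // NHWC inputs are transposed to NCHW so that the channels of one group
    // are contiguous; NCHW inputs are copied as-is.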
    if (data_layout != DataLayout::kNCHW) {
      xnorm.Resize({x->dims()[0], x->dims()[3], x->dims()[1], x->dims()[2]});
      F.Transpose(x, &xnorm, std::vector<int>{0, 3, 1, 2});
    } else {
      TensorCopy(*x, platform::NPUPlace(), &xnorm);
    }
    auto N = xnorm.dims()[0];
    auto C = xnorm.dims()[1];
    auto H = xnorm.dims()[2];
    auto W = xnorm.dims()[3];
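    // Flatten to [N * groups, C * H * W / groups]; row i holds all elements
    // of one (sample, group) pair, so reductions over axis 1 are exactly the
    // per-group statistics.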
    xnorm.Resize({N * groups, C * H * W / groups});
    std::vector<int> axis = {1};
    auto reduce_dim = mean->dims();

    mean->mutable_data<T>({N * groups, 1}, place);
    var->mutable_data<T>({N * groups, 1}, place);
    y->mutable_data<T>(place);
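    // y = (x - mean) / sqrt(var + epsilon), built from elementwise NPU ops:
    // center, square, reduce to the variance, then divide by the std.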
    F.ReduceMean(&xnorm, mean, axis);

    F.Sub(&xnorm, mean, &xnorm);
    Tensor sqr(x->type());
    sqr.mutable_data<T>(xnorm.dims(), place);

    F.Mul(&xnorm, &xnorm, &sqr);
    F.ReduceMean(&sqr, var, axis);
    Tensor std(x->type());
    std.mutable_data<T>(var->dims(), place);
    F.Adds(var, epsilon, &std);
    F.Sqrt(&std, &std);
    y->Resize(xnorm.dims());
    F.Div(&xnorm, &std, y);
    y->Resize({N, C, H, W});
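    // Scale and Bias hold C elements; viewing them as [C, 1, 1] lets the NPU
    // ops broadcast them over the H and W axes of the NCHW result.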
    if (scale) {
      Tensor scale_t(scale->type());
      scale_t.ShareDataWith(*scale);
      scale_t.Resize({C, 1, 1});
      F.Mul(y, &scale_t, y);
    }
    if (bias) {
      Tensor bias_t(bias->type());
      bias_t.ShareDataWith(*bias);
      bias_t.Resize({C, 1, 1});
      F.Add(y, &bias_t, y);
    }
    if (data_layout != DataLayout::kNCHW) {
      F.Transpose(y, y, std::vector<int>{0, 2, 3, 1});
      y->Resize({x->dims()});
    }
    mean->Resize(reduce_dim);
    var->Resize(reduce_dim);
  }
};

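// Backward pass: x_norm is reconstructed from the forward output as
// (y - bias) / scale, d_bias and d_scale are plain reductions of d_y, and
// d_x follows from the chain rule through the per-group normalization.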
template <typename T>
class GroupNormGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const float epsilon = ctx.Attr<float>("epsilon");
    auto* y = ctx.Input<Tensor>("Y");
    auto* var = ctx.Input<Tensor>("Variance");

    auto* scale = ctx.Input<Tensor>("Scale");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto G = ctx.Attr<int>("groups");

    // init output
    auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    GroupNormFunction<T> F(ctx);
    auto place = ctx.GetPlace();
    auto _type = y->type();

    Tensor xnorm(_type);
    xnorm.mutable_data<T>(y->dims(), place);
    Tensor scale_share(_type);
    scale_share.ShareDataWith(*scale);
    Tensor bias_share(_type);
    bias_share.ShareDataWith(*bias);

    int64_t N = y->dims()[0];
    int64_t C, H, W;
    framework::DDim scale_bias_dim;
    if (data_layout == DataLayout::kNCHW) {
      C = y->dims()[1];
      H = y->dims()[2];
      W = y->dims()[3];
      scale_bias_dim = framework::make_ddim({C, 1, 1});
    } else {
      C = y->dims()[3];
      H = y->dims()[1];
      W = y->dims()[2];
      scale_bias_dim = framework::make_ddim({1, 1, C});
    }
    scale_share.Resize(scale_bias_dim);
    bias_share.Resize(scale_bias_dim);
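    // Recover x_norm = (y - bias) / scale from the forward output; DivNoNan
    // keeps the result finite where a scale element is zero.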
    F.Sub(y, &bias_share, &xnorm);
    F.DivNoNan(&xnorm, &scale_share, &xnorm);

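    // d_bias = sum(d_y) and d_scale = sum(d_y * x_norm), each reduced over
    // every axis except the channel axis.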
    if (d_bias) {
      d_bias->mutable_data<T>(place);
      if (data_layout == DataLayout::kNCHW) {
        F.ReduceSum(d_y, d_bias, std::vector<int>{0, 2, 3}, false);
      } else {
        F.ReduceSum(d_y, d_bias, std::vector<int>{0, 1, 2}, false);
      }
    }
    if (d_scale) {
      d_scale->mutable_data<T>(place);
      Tensor dy_xnorm(_type);
      dy_xnorm.mutable_data<T>(d_y->dims(), place);
      F.Mul(d_y, &xnorm, &dy_xnorm);
      if (data_layout == DataLayout::kNCHW) {
        F.ReduceSum(&dy_xnorm, d_scale, std::vector<int>{0, 2, 3});
      } else {
        F.ReduceSum(&dy_xnorm, d_scale, std::vector<int>{0, 1, 2});
      }
    }

    // std = Sqrt(var + epsilon), initial shape is [N, G]
    Tensor std(_type);
    std.mutable_data<T>(var->dims(), place);
    F.Adds(var, epsilon, &std);
    F.Sqrt(&std, &std);
    // d_xnorm_std = d_y * scale / std
    Tensor d_xnorm_std(_type);
    d_xnorm_std.mutable_data<T>(y->dims(), place);
    F.Mul(d_y, &scale_share, &d_xnorm_std);
    if (data_layout == DataLayout::kNCHW) {
      xnorm.Resize({N, G, C * H * W / G});
      d_xnorm_std.Resize({N, G, C * H * W / G});
      std.Resize({N, G, 1});
    } else {
      xnorm.Resize({N, C * H * W / G, G});
      d_xnorm_std.Resize({N, C * H * W / G, G});
      std.Resize({N, 1, G});
    }
    F.Div(&d_xnorm_std, &std, &d_xnorm_std);

    // d_x = d_xnorm_std
    //       - Mean(d_xnorm_std * x_norm, over each group, keepdims) * x_norm
    //       - Mean(d_xnorm_std, over each group, keepdims)
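    // This is the chain rule through x_norm = (x - mean) / std applied per
    // group: d_xnorm_std already carries the 1/std factor, and the two Mean
    // terms come from differentiating the group mean and the group variance.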
    d_x->mutable_data<T>(place);
    d_x->Resize(xnorm.dims());
    F.Mul(&d_xnorm_std, &xnorm, d_x);
    Tensor dx1 = F.ReduceMeanToNG(d_x, data_layout, N, C, H, W, G);
    F.Mul(&dx1, &xnorm, d_x);

    Tensor dx2 = F.ReduceMeanToNG(&d_xnorm_std, data_layout, N, C, H, W, G);

    F.Sub(&d_xnorm_std, d_x, d_x);
    F.Sub(d_x, &dx2, d_x);

    d_x->Resize(y->dims());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(group_norm, ops::GroupNormNPUKernel<float>,
                       ops::GroupNormNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(group_norm_grad, ops::GroupNormGradNPUKernel<float>,
                       ops::GroupNormGradNPUKernel<plat::float16>);