/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

| 15 | +#include <memory> |
| 16 | +#include <string> |
| 17 | +#include "paddle/fluid/framework/op_registry.h" |
| 18 | + |
| 19 | +namespace paddle { |
| 20 | +namespace operators { |
| 21 | + |
| 22 | +using Tensor = framework::Tensor; |
| 23 | + |
| 24 | +class FusedMultiTransformerOp : public framework::OperatorWithKernel { |
| 25 | + private: |
| 26 | + static constexpr const char *OpName = "FusedMultiTransformerOp"; |
| 27 | + |
| 28 | + public: |
| 29 | + using framework::OperatorWithKernel::OperatorWithKernel; |
| 30 | + |
| 31 | + void InferShape(framework::InferShapeContext *ctx) const override { |
| 32 | +#define CHECK_INPUT(name) \ |
| 33 | + OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName) |
| 34 | +#define CHECK_INPUTS(name) \ |
| 35 | + OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName) |
| 36 | +#define CHECK_OUTPUT(name) \ |
| 37 | + OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName) |
| 38 | +#define CHECK_OUTPUTS(name) \ |
| 39 | + OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName) |
| 40 | + |
| 41 | + CHECK_INPUT(X); |
| 42 | + |
| 43 | + // attention |
| 44 | + CHECK_INPUTS(QKVW); |
| 45 | + CHECK_INPUTS(OutLinearW); |
| 46 | + |
| 47 | + if (ctx->HasInput("TimeStep")) { |
| 48 | + CHECK_INPUTS(CacheKV); |
| 49 | + } |
| 50 | + |
| 51 | + if (ctx->HasInputs("CacheKV")) { |
| 52 | + CHECK_OUTPUTS(CacheKVOut); |
| 53 | + } |
| 54 | + |
| 55 | + // ffn |
| 56 | + CHECK_INPUTS(FFN1Weight); |
| 57 | + CHECK_INPUTS(FFN2Weight); |
| 58 | + |
| 59 | + CHECK_OUTPUT(Out); |
| 60 | + |
| 61 | + // x: qkv's input [batch_size, seq_len, dim_embed] |
| 62 | + // y: qkv's weight: [3, num_head, dim_head, dim_embed] |
| 63 | + auto x_dim = ctx->GetInputDim("X"); |
| 64 | + auto y_dim = ctx->GetInputsDim("QKVW")[0]; |
| 65 | + PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument( |
| 66 | + "The dimensions of x must be 3" |
| 67 | + "(batch_size, seq_len, dim_embed)," |
| 68 | + "but received dimensions of" |
| 69 | + "Input is [%d]", |
| 70 | + x_dim.size())); |
| 71 | + PADDLE_ENFORCE_EQ(y_dim.size(), 4, |
| 72 | + platform::errors::InvalidArgument( |
| 73 | + "The dimensions of qkv_weight must be 4" |
| 74 | + "(3, num_head, dim_head, dim_embed)," |
| 75 | + "but received dimensions of" |
| 76 | + "Input is [%d]", |
| 77 | + y_dim.size())); |
| 78 | + PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3], |
| 79 | + platform::errors::InvalidArgument( |
| 80 | + "ShapeError: the dimension of x_dim[2] and y_dim[3]" |
| 81 | + "must be equal. But received: the shape " |
| 82 | + "of input x = [%s], and the shape of " |
| 83 | + "input qkv_weight = [%s]", |
| 84 | + x_dim, y_dim)); |
| 85 | + |
| 86 | + if (ctx->Attrs().Get<int>("ring_id") == -1) { |
| 87 | + PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3], |
| 88 | + platform::errors::InvalidArgument( |
| 89 | + "The dimensions of qkv_weight must be 4" |
| 90 | + "(3, num_head, dim_head, dim_embed)," |
| 91 | + "and must satisfy the limitations: " |
| 92 | + "(num_head * dim_head == dim_embed)")); |
| 93 | + } |
| 94 | + |
| 95 | + if (ctx->HasInputs("CacheKV")) { |
| 96 | + // [2, batch_size, num_head, max_seq_len, head_size] |
| 97 | + const auto &c_dims = ctx->GetInputsDim("CacheKV"); |
| 98 | + const auto &c_dim = c_dims[0]; |
| 99 | + |
| 100 | + PADDLE_ENFORCE_EQ( |
| 101 | + c_dim.size(), 5, |
| 102 | + paddle::platform::errors::InvalidArgument( |
| 103 | + "The CacheKV must be 5 dims, but got %d", c_dim.size())); |
| 104 | + PADDLE_ENFORCE_EQ(c_dim[0], 2, |
| 105 | + paddle::platform::errors::InvalidArgument( |
| 106 | + "The first dim of CacheKV must be 2, but got %d", |
| 107 | + c_dim[0])); // 2 |
| 108 | + PADDLE_ENFORCE_EQ(c_dim[1], x_dim[0], |
| 109 | + paddle::platform::errors::InvalidArgument( |
| 110 | + "The second dim of CacheKV must be equal with " |
| 111 | + "batch size %d, but got %d", |
| 112 | + x_dim[0], c_dim[1])); // batch_size |
| 113 | + PADDLE_ENFORCE_EQ(c_dim[2], y_dim[1], |
| 114 | + paddle::platform::errors::InvalidArgument( |
| 115 | + "The third dim of CacheKV must be equal with num " |
| 116 | + "head %d, but got %d", |
| 117 | + y_dim[1], c_dim[2])); // num_head |
| 118 | + PADDLE_ENFORCE_GT( |
| 119 | + c_dim[3], 0, |
| 120 | + paddle::platform::errors::InvalidArgument( |
| 121 | + "The forth dim of CacheKV must be greater than 0, but got %d", |
| 122 | + c_dim[3])); // cache_seq_len |
| 123 | + PADDLE_ENFORCE_EQ(c_dim[4], y_dim[2], |
| 124 | + paddle::platform::errors::InvalidArgument( |
| 125 | + "The fifth dim of CacheKV must be equal with head " |
| 126 | + "size %d, but got %d", |
| 127 | + y_dim[2], c_dim[4])); // head_size |
| 128 | + } |
| 129 | + |
| 130 | + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); |
| 131 | + } |
| 132 | + |
| 133 | + protected: |
| 134 | + framework::OpKernelType GetExpectedKernelType( |
| 135 | + const framework::ExecutionContext &ctx) const override { |
| 136 | + return framework::OpKernelType( |
| 137 | + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); |
| 138 | + } |
| 139 | + |
| 140 | + framework::OpKernelType GetKernelTypeForVar( |
| 141 | + const std::string &var_name, const Tensor &tensor, |
| 142 | + const framework::OpKernelType &expected_kernel_type) const override { |
| 143 | + if (var_name == "TimeStep") { |
| 144 | + VLOG(10) << "var_name:" << var_name << " need not to transform"; |
| 145 | + return expected_kernel_type; |
| 146 | + } |
| 147 | + return framework::OpKernelType(expected_kernel_type.data_type_, |
| 148 | + tensor.place(), tensor.layout()); |
| 149 | + } |
| 150 | +}; |
| 151 | + |
| 152 | +class FusedMultiTransformerOpOpMaker |
| 153 | + : public framework::OpProtoAndCheckerMaker { |
| 154 | + public: |
| 155 | + void Make() override { |
| 156 | + AddInput("X", "The input tensor."); |
| 157 | + AddInput("LnScale", |
| 158 | + "Scale is a 1-dimensional tensor of size " |
| 159 | + "H. Here, H represents the last dimension of its input tensor.") |
| 160 | + .AsDuplicable(); |
| 161 | + AddInput("LnBias", |
| 162 | + "Bias is a 1-dimensional tensor of size " |
| 163 | + "H. Here, H represents the last dimension of its input tensor.") |
| 164 | + .AsDuplicable(); |
| 165 | + AddInput("QKVW", "The qkv weight tensor.").AsDuplicable(); |
| 166 | + AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable(); |
| 167 | + AddInput("CacheKV", "(optional) The cached KV for generation inference.") |
| 168 | + .AsDispensable() |
| 169 | + .AsDuplicable(); |
| 170 | + AddInput("TimeStep", |
| 171 | + "(optional, int) The time step for generation inference.") |
| 172 | + .AsDispensable(); |
| 173 | + AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") |
| 174 | + .AsDispensable(); |
| 175 | + AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); |
| 176 | + AddInput("OutLinearBias", "The out_linear bias tensor.") |
| 177 | + .AsDispensable() |
| 178 | + .AsDuplicable(); |
| 179 | + |
| 180 | + AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") |
| 181 | + .AsDuplicable(); |
| 182 | + AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") |
| 183 | + .AsDuplicable(); |
| 184 | + AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op") |
| 185 | + .AsDuplicable(); |
| 186 | + AddInput("FFN1Bias", "The linear1 bias of FusedFeedForward op") |
| 187 | + .AsDispensable() |
| 188 | + .AsDuplicable(); |
| 189 | + AddInput("FFN2Weight", "The linear2 weight of FusedFeedForward op") |
| 190 | + .AsDuplicable(); |
| 191 | + AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op") |
| 192 | + .AsDispensable() |
| 193 | + .AsDuplicable(); |
| 194 | + |
| 195 | + AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") |
| 196 | + .AsDispensable() |
| 197 | + .AsDuplicable(); |
| 198 | + AddOutput("Out", "Result after multi ."); |
| 199 | + |
| 200 | + AddAttr<bool>("pre_layer_norm", |
| 201 | + "if true, the attention op uses pre_layer_norm architecure, " |
| 202 | + "else, uses post_layer_norm architecuture. " |
| 203 | + "[default true].") |
| 204 | + .SetDefault(true); |
| 205 | + AddAttr<float>("epsilon", |
| 206 | + "Constant for numerical stability [default 1e-5].") |
| 207 | + .SetDefault(1e-5) |
| 208 | + .AddCustomChecker([](const float &epsilon) { |
| 209 | + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true, |
| 210 | + platform::errors::InvalidArgument( |
| 211 | + "'epsilon' in Op(LayerNorm) should be between" |
| 212 | + "0.0 and 0.001, But received [%s].", |
| 213 | + epsilon)); |
| 214 | + }); |
| 215 | + |
| 216 | + AddAttr<float>("dropout_rate", "Probability of setting units to zero.") |
| 217 | + .SetDefault(.5f) |
| 218 | + .AddCustomChecker([](const float &drop_p) { |
| 219 | + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true, |
| 220 | + platform::errors::InvalidArgument( |
| 221 | + "'dropout_rate' must be between 0.0 and 1.0.")); |
| 222 | + }); |
| 223 | + |
| 224 | + AddAttr<bool>("dropout_is_test", |
| 225 | + "(bool, default false) Set to true for inference only, false " |
| 226 | + "for training. Some layers may run faster when this is true.") |
| 227 | + .SetDefault(false); |
| 228 | + AddAttr<std::string>( |
| 229 | + "dropout_implementation", |
| 230 | + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" |
| 231 | + "The meaning is the same as 'attn_dropout_implementation'.") |
| 232 | + .SetDefault("downgrade_in_infer") |
| 233 | + .AddCustomChecker([](const std::string &type) { |
| 234 | + PADDLE_ENFORCE_EQ( |
| 235 | + type == "downgrade_in_infer" || type == "upscale_in_train", true, |
| 236 | + platform::errors::InvalidArgument( |
| 237 | + "dropout_implementation can only be downgrade_in_infer or " |
| 238 | + "upscale_in_train")); |
| 239 | + }); |
| 240 | + AddAttr<std::string>("act_method", "act_method").SetDefault("gelu"); |
| 241 | + |
| 242 | + AddAttr<int>( |
| 243 | + "ring_id", |
| 244 | + "ring id for tensor model parallel. distributed training and inference") |
| 245 | + .SetDefault(-1); |
| 246 | + |
| 247 | + AddComment(R"DOC(fused multi transformer layers op)DOC"); |
| 248 | + } |
| 249 | +}; |
| 250 | + |
| 251 | +} // namespace operators |
| 252 | +} // namespace paddle |
| 253 | + |
| 254 | +namespace ops = paddle::operators; |
| 255 | +REGISTER_OPERATOR( |
| 256 | + fused_multi_transformer, ops::FusedMultiTransformerOp, |
| 257 | + ops::FusedMultiTransformerOpOpMaker, |
| 258 | + paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, |
| 259 | + paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); |