#include <algorithm>
#include <bitset>
#include <string>
#include <vector>

#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

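// Helper that wraps INetworkDefinition::addElementWise. TensorRT elementwise
// layers require both operands to have the same rank, so the lower-rank tensor
// is reshaped (padded with leading size-1 dimensions) before the op is added.
// If the lower-rank tensor has dynamic (-1) dimensions, its target shape is
// computed at runtime from the shape of the higher-rank tensor.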
nvinfer1::ILayer* add_elementwise(
    ConversionCtx* ctx,
    nvinfer1::ElementWiseOperation op,
    nvinfer1::ITensor* self,
    nvinfer1::ITensor* other,
    const std::string& name) {
  // Ensure self has the larger number of dimensions
  bool swapSelfOther = false;
  if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
    std::swap(self, other);
    swapSelfOther = true;
  }
  auto selfDim = util::toVec(self->getDimensions());
  auto otherDim = util::toVec(other->getDimensions());
  if (selfDim.size() != otherDim.size()) {
    // other has a dynamic shape; expand its dimensions now and resolve its shape at runtime
    if (otherDim.end() != std::find(otherDim.begin(), otherDim.end(), -1)) {
      auto thOtherStaticShapeMask = torch::ones(selfDim.size(), torch::kInt32);
      auto thOtherDynamicShapeMask = torch::zeros(selfDim.size(), torch::kInt32);
      for (size_t start = selfDim.size() - otherDim.size(), idx = 0; idx < otherDim.size(); ++idx) {
        if (-1 != otherDim[idx]) {
          thOtherStaticShapeMask[start + idx] = otherDim[idx];
        } else {
          thOtherStaticShapeMask[start + idx] = 0;
          thOtherDynamicShapeMask[start + idx] = 1;
        }
      }
      auto otherStaticShapeMask = tensor_to_const(ctx, thOtherStaticShapeMask);
      auto otherDynamicShapeMask = tensor_to_const(ctx, thOtherDynamicShapeMask);
      auto selfShape = ctx->net->addShape(*self)->getOutput(0);
      // The size of each dynamic dimension of other must match the corresponding dimension of self
      auto otherDynamicShape =
          ctx->net->addElementWise(*selfShape, *otherDynamicShapeMask, nvinfer1::ElementWiseOperation::kPROD)
              ->getOutput(0);
      auto targetOtherShape =
          ctx->net->addElementWise(*otherDynamicShape, *otherStaticShapeMask, nvinfer1::ElementWiseOperation::kSUM)
              ->getOutput(0);

      auto otherShuffle = ctx->net->addShuffle(*other);
      otherShuffle->setName(std::string("Reshape other tensor to have the same nDim as self for " + name).c_str());
      otherShuffle->setInput(1, *targetOtherShape);
      other = otherShuffle->getOutput(0);
    } else {
      // other has a static shape; pad its dimensions so the two tensors have the same rank
      auto otherShuffle = ctx->net->addShuffle(*other);
      otherShuffle->setReshapeDimensions(util::toDimsPad(otherDim, selfDim.size()));
      other = otherShuffle->getOutput(0);
    }
  }
  if (swapSelfOther) {
    // swap back
    std::swap(self, other);
    swapSelfOther = false;
  }
  auto ele = ctx->net->addElementWise(*self, *other, op);
  ele->setName(name.c_str());
  return ele;
}

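// Converter for aten::layer_norm. TensorRT (at the time of writing) has no
// dedicated layer normalization layer, so the op is decomposed as
//   y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
// built from reduce, elementwise, unary and scale layers, with the statistics
// taken over the last normalized_shape.size() dimensions of the input.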
auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
    R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta,
                           float eps, bool cudnn_enabled) -> (Tensor))SIG",
    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

      auto input = args[0].ITensor(); // assumes non-static input Tensor
      auto orig_shape = input->getDimensions();
      auto shape = util::toVec(orig_shape);

      /* aten::layer_norm normalizes over the last N dimensions;
         normalized_shape could be (C,H,W), (H,W), or (W). */
      auto normalized_shape = args[1].unwrapToIntList();
      auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));

      torch::Tensor gamma, beta;
      gamma = args[2].unwrapToTensor();
      beta = args[3].unwrapToTensor();

      // Remove the batch dimension from the input shape to get expand_size,
      // which is used to create the weights for addScaleNd later.
      auto expand_size = shape;
      expand_size.erase(expand_size.begin(), expand_size.begin() + 1);
      auto gamma_expand = gamma.expand(expand_size);
      auto beta_expand = beta.expand(expand_size);

      // Unwrap eps.
      auto eps = args[4].unwrapToDouble();
      LOG_DEBUG("cudnn_enabled argument is ignored");

      // Set up axis_mask for E[x]: one bit per normalized (trailing) dimension.
      uint32_t axis_mask = 0;
      for (size_t i = 0; i < normalized_shape_vec.size(); i++) {
        axis_mask |= 1 << (shape.size() - i - 1);
      }
      LOG_DEBUG("Axis Mask for E[x]: " << std::bitset<32>(axis_mask));
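      // For example, assuming a rank-4 NCHW input normalized over (C, H, W):
      // shape.size() == 4 and normalized_shape_vec.size() == 3 set bits 3, 2
      // and 1, so axis_mask == 0b1110, i.e. reduce over axes 1, 2 and 3.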

      // E[x]
      auto mean_layer_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
      TRTORCH_CHECK(mean_layer_expected, "Unable to create mean_layer_expected from node: " << *n);
      mean_layer_expected->setName((util::node_info(n) + "_mean_expected").c_str());
      auto mean_layer_expected_out = mean_layer_expected->getOutput(0);

      // Expand the output of E[x] back to the shape of the original input:
      // each normalized dimension is repeated shape[i] times, the rest once.
      c10::List<int64_t> repeats_expected;
      for (size_t i = 0; i < shape.size(); i++) {
        auto repeat = i > (shape.size() - normalized_shape_vec.size() - 1) ? shape[i] : 1;
        repeats_expected.push_back(repeat);
      }

      int repeats_expected_rank = repeats_expected.size(); // e.g. 4 for an NCHW input
      auto mean_layer_expected_out_dims = mean_layer_expected_out->getDimensions(); // e.g. rank 1 after the reduce
      auto num_expand_dims_expected = repeats_expected_rank - mean_layer_expected_out_dims.nbDims; // e.g. 3

      if (num_expand_dims_expected > 0) {
        nvinfer1::Dims reshape_expected_dims;
        reshape_expected_dims.nbDims = repeats_expected.size();
        for (int i = 0; i < num_expand_dims_expected; i++) {
          reshape_expected_dims.d[repeats_expected.size() - 1 - i] = 1;
        }
        for (int i = 0; i < mean_layer_expected_out_dims.nbDims; i++) {
          reshape_expected_dims.d[i] = mean_layer_expected_out_dims.d[i];
        }
        // Add a reshape layer to expand dims
        auto reshape_layer_expected = ctx->net->addShuffle(*mean_layer_expected_out);
        reshape_layer_expected->setReshapeDimensions(reshape_expected_dims);
        mean_layer_expected_out = reshape_layer_expected->getOutput(0);
      }

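      // Emulate a broadcast/expand by concatenating the reduced tensor with
      // itself along each repeated axis until it matches the input shape.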
      for (int i = repeats_expected.size() - 1; i >= 0; --i) {
        std::vector<nvinfer1::ITensor*> tensors_vec;
        for (int j = 0; j < repeats_expected[i]; j++) {
          tensors_vec.push_back(mean_layer_expected_out);
        }
        auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
        concat_layer->setAxis(i);
        mean_layer_expected_out = concat_layer->getOutput(0);
      }

      // x - E[x]
      auto sub = add_elementwise(
          ctx, nvinfer1::ElementWiseOperation::kSUB, input, mean_layer_expected_out, util::node_info(n) + "_sub");
      TRTORCH_CHECK(sub, "Unable to create Sub layer from node: " << *n);
      auto xsubmean = sub->getOutput(0);

      // Var[x] = E[(x - E[x])^2]
      float pow_scalar = 2;
      auto exponent = tensor_to_const(ctx, torch::tensor({pow_scalar}));
      auto pow = add_elementwise(
          ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean, exponent, util::node_info(n) + "_pow");
      TRTORCH_CHECK(pow, "Unable to create Power layer from node: " << *n);
      auto pow_out = pow->getOutput(0);

      auto mean_layer_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
      TRTORCH_CHECK(mean_layer_var, "Unable to create mean_layer_var from node: " << *n);
      mean_layer_var->setName((util::node_info(n) + "_mean_var").c_str());
      auto mean_layer_var_out = mean_layer_var->getOutput(0);

      // Expand the output of mean_layer_var to the same shape as the original
      // input, using the same scheme as for E[x] above.
      c10::List<int64_t> repeats_var;
      for (size_t i = 0; i < shape.size(); i++) {
        auto repeat = i > (shape.size() - normalized_shape_vec.size() - 1) ? shape[i] : 1;
        repeats_var.push_back(repeat);
      }

      int repeats_var_rank = repeats_var.size();
      auto mean_layer_var_out_dims = mean_layer_var_out->getDimensions();
      auto num_expand_dims_var = repeats_var_rank - mean_layer_var_out_dims.nbDims;

      if (num_expand_dims_var > 0) {
        nvinfer1::Dims reshape_dims_var;
        reshape_dims_var.nbDims = repeats_var.size();
        for (int i = 0; i < num_expand_dims_var; i++) {
          reshape_dims_var.d[repeats_var.size() - 1 - i] = 1;
        }
        for (int i = 0; i < mean_layer_var_out_dims.nbDims; i++) {
          reshape_dims_var.d[i] = mean_layer_var_out_dims.d[i];
        }
        // Add a reshape layer to expand dims
        auto reshape_layer_var = ctx->net->addShuffle(*mean_layer_var_out);
        reshape_layer_var->setReshapeDimensions(reshape_dims_var);
        mean_layer_var_out = reshape_layer_var->getOutput(0);
      }

      for (int i = repeats_var.size() - 1; i >= 0; --i) {
        std::vector<nvinfer1::ITensor*> tensors_vec;
        for (int j = 0; j < repeats_var[i]; j++) {
          tensors_vec.push_back(mean_layer_var_out);
        }
        auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
        concat_layer->setAxis(i);
        mean_layer_var_out = concat_layer->getOutput(0);
      }

      // Var[x] + eps
      auto eps_tensor = tensor_to_const(ctx, torch::tensor({eps}));
      auto add = add_elementwise(
          ctx, nvinfer1::ElementWiseOperation::kSUM, mean_layer_var_out, eps_tensor, util::node_info(n) + "_add");
      TRTORCH_CHECK(add, "Unable to create Add layer from node: " << *n);
      auto add_out = add->getOutput(0);

      // sqrt(Var[x] + eps)
      auto unary = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
      TRTORCH_CHECK(unary, "Unable to create unary layer from node: " << *n);
      unary->setName((util::node_info(n) + "_unary_sqrt").c_str());
      auto unary_out = unary->getOutput(0);

      // (x - E[x]) / sqrt(Var[x] + eps)
      auto div = add_elementwise(
          ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean, unary_out, util::node_info(n) + "_div");
      TRTORCH_CHECK(div, "Unable to create Div layer from node: " << *n);
      auto div_out = div->getOutput(0);

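      // The affine part is applied with an elementwise IScaleLayer, which
      // computes y = (x * scale + shift)^power per element, with shift = beta,
      // scale = gamma and power = 1 (a tensor of ones); the weights were
      // expanded above to the input shape minus the batch dimension.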
      // Set up gamma_weights and beta_weights from gamma_expand and beta_expand
      auto gamma_weights = Weights(ctx, gamma_expand);
      auto beta_weights = Weights(ctx, beta_expand);

      auto power = Weights(ctx, at::ones_like(gamma_expand));
      auto scale_nd = ctx->net->addScaleNd(
          *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
      TRTORCH_CHECK(scale_nd, "Unable to create Scale layer from node: " << *n);
      scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
      auto scale_nd_out = scale_nd->getOutput(0);

      ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out);
      return true;
    }});
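
// Example of a TorchScript node this converter matches (shapes hypothetical):
//   %out : Tensor = aten::layer_norm(%input, %normalized_shape, %gamma, %beta, %eps, %cudnn_enabled)
// e.g. an (N, C, H, W) input with normalized_shape = [C, H, W], for which the
// engine normalizes each sample over its last three dimensions.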

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch