|
| 1 | +/* |
| 2 | + * SPDX-License-Identifier: Apache-2.0 |
| 3 | + * |
| 4 | + */ |
| 5 | + |
| 6 | +#include "AttentionHelpers.hpp" |
| 7 | +#include "ImporterContext.hpp" |
| 8 | +#include "NvInfer.h" |
| 9 | +#include "ShapeTensor.hpp" |
| 10 | +#include "errorHelpers.hpp" |
| 11 | +#include "importerUtils.hpp" |
| 12 | +#include <cmath> |
| 13 | +#include <numeric> |
| 14 | +#include <string> |
| 15 | +#include <vector> |
| 16 | + |
namespace
{
//!
//! \brief Return true if `dividend` is divisible by `divisor`.
//!
//! A divisor of 0 never divides anything (this also guards against UB in `%`).
//!
bool isDivisible(int64_t const dividend, int64_t const divisor)
{
    return (divisor != 0) && ((dividend % divisor) == 0);
}
} // namespace
| 27 | + |
| 28 | +namespace onnx2trt |
| 29 | +{ |
| 30 | + |
| 31 | +//! |
| 32 | +//! \brief Reshape and return the Q, K, or V tensor from the input tensor. |
| 33 | +//! |
| 34 | +//! \param qkvInput The input tensor. This can either be a 4D tensor (batchSize, numHeads, sequenceLength, headSize) or |
| 35 | +//! a 3D tensor (batchSize, sequenceLength, hiddenSize=numHeads*headSize). If it is a 3D tensor, |
| 36 | +//! permute and reshape to the 4D shape before returning. Otherwise, return the input tensor. |
| 37 | +//! \param attrs The ONNX node attributes. |
| 38 | +//! \param ctx The importer context. |
| 39 | +//! \param isQ True if the input tensor is the Q tensor, false if it is the K or V tensor. |
| 40 | +//! \return nvinfer1::ITensor& The Q, K, or V tensor. |
| 41 | +//! |
| 42 | +nvinfer1::ITensor& reshapeQKVTensor( |
| 43 | + TensorOrWeights& qkvInput, OnnxAttrs const& attrs, ImporterContext* ctx, bool const isQ) |
| 44 | +{ |
| 45 | + if (qkvInput.shape().nbDims == 3) |
| 46 | + { |
| 47 | + // qkvInput is a 3D tensor (batchSize, sequenceLength, hiddenSize=numHeads * headSize). |
| 48 | + // Get relevant dimensions. |
| 49 | + int64_t const numHeadsValue |
| 50 | + = isQ ? attrs.get<int64_t>("q_num_heads", 0) : attrs.get<int64_t>("kv_num_heads", 0); |
| 51 | + ONNXTRT_CHECK(numHeadsValue != 0, |
| 52 | + "q_num_heads and kv_num_heads attributes are not specified, which are required for 3D Q/K/V tensors", |
| 53 | + ErrorCode::kINVALID_NODE); |
| 54 | + ShapeTensor numHeads = shapeVector(numHeadsValue); |
| 55 | + |
| 56 | + ShapeTensor hiddenSize = gather(ctx, shapeOf(qkvInput), shapeVector(2)); |
| 57 | + if (hiddenSize.allValuesKnown()) |
| 58 | + { |
| 59 | + // Perform static check for divisibility. |
| 60 | + ONNXTRT_CHECK(isDivisible(hiddenSize[0], numHeads[0]), |
| 61 | + "hidden_size must be divisible by num_heads. Received hidden_size=" << hiddenSize[0] |
| 62 | + << " and num_heads=" << numHeads, |
| 63 | + ErrorCode::kINVALID_NODE); |
| 64 | + } |
| 65 | + |
| 66 | + ShapeTensor headSize = floorDiv(ctx, hiddenSize, numHeads); |
| 67 | + |
| 68 | + // == Transform (batchSize, sequenceLength, hiddenSize) -> (batchSize, numHeads, sequenceLength, headSize) by == |
| 69 | + // 1. Reshape to (batchSize, sequenceLength, numHeads, headSize). |
| 70 | + // Use (0, 0, numHeads, headSize) as a shorthand to propagate `batchSize` and `sequenceLength` from the input |
| 71 | + // tensor without instantiating them. Set `zeroIsPlaceholder` to enable this shorthand. |
| 72 | + ShapeTensor newShape = concat(ctx, fillShapeVector(ctx, 0, shapeVector(2)), concat(ctx, numHeads, headSize)); |
| 73 | + nvinfer1::IShuffleLayer* shuffle |
| 74 | + = addShuffle(ctx, convertToTensor(qkvInput, ctx), newShape, /*zeroIsPlaceholder*/ true); |
| 75 | + |
| 76 | + // 2. Permute to (batchSize, numHeads, sequenceLength, headSize) |
| 77 | + shuffle->setSecondTranspose({0, 2, 1, 3}); |
| 78 | + |
| 79 | + return *N_CHECK(shuffle->getOutput(0)); |
| 80 | + } |
| 81 | + else |
| 82 | + { |
| 83 | + return convertToTensor(qkvInput, ctx); |
| 84 | + } |
| 85 | +} |
| 86 | + |
| 87 | +//! |
| 88 | +//! \brief Scale the Q or K tensor by `sqrt(scale)`. |
| 89 | +//! |
| 90 | +//! `scale` is either provided as an attribute or set as the default value of `1/sqrt(headSize)`. `scale` is defined as |
| 91 | +//! `QK^T -> QK^T * scale`, but we apply `Q -> Q * sqrt(scale)` and `K -> K * sqrt(scale)` for numerical stability. |
| 92 | +//! |
| 93 | +//! \param qkTensor The Q or K tensor to scale. |
| 94 | +//! \param attrs The ONNX node attributes. |
| 95 | +//! \param ctx The importer context. |
| 96 | +//! \return nvinfer1::ITensor& The scaled Q or K tensor. |
| 97 | +//! |
| 98 | +nvinfer1::ITensor& scaleQKTensor(nvinfer1::ITensor& qkTensor, OnnxAttrs const& attrs, ImporterContext* ctx) |
| 99 | +{ |
| 100 | + nvinfer1::ITensor* sqrtScale = nullptr; |
| 101 | + |
| 102 | + if (attrs.count("scale")) |
| 103 | + { |
| 104 | + // Obtain the sqrt of scale as a constant (output of a constant layer). |
| 105 | + nvinfer1::IConstantLayer* constant = addConstantScalar( |
| 106 | + ctx, std::sqrt(attrs.get<float>("scale")), ::ONNX_NAMESPACE::TensorProto::FLOAT, {4, {1, 1, 1, 1}}); |
| 107 | + sqrtScale = castHelper(ctx, N_CHECK(constant)->getOutput(0), qkTensor.getType()); |
| 108 | + } |
| 109 | + else |
| 110 | + { |
| 111 | + ShapeTensor headSize = gather(ctx, shapeOf(qkTensor), shapeScalar(3)); |
| 112 | + nvinfer1::ITensor* headSizeF = castHelper(ctx, &headSize.tensor(ctx), qkTensor.getType()); |
| 113 | + |
| 114 | + // By default, scale := 1/sqrt(headSize) |
| 115 | + nvinfer1::ITensor* sqrtHeadSize = getUnaryResult(ctx, *headSizeF, nvinfer1::UnaryOperation::kSQRT); |
| 116 | + nvinfer1::ITensor* scale = getUnaryResult(ctx, *sqrtHeadSize, nvinfer1::UnaryOperation::kRECIP); |
| 117 | + |
| 118 | + sqrtScale = getUnaryResult(ctx, *scale, nvinfer1::UnaryOperation::kSQRT); |
| 119 | + sqrtScale = unsqueezeTensor(ctx, *sqrtScale, {0, 1, 2, 3}); |
| 120 | + } |
| 121 | + |
| 122 | + // Scale Q or K tensor by `sqrt(scale)`. |
| 123 | + return *getElementWiseResult(ctx, qkTensor, *sqrtScale, nvinfer1::ElementWiseOperation::kPROD); |
| 124 | +} |
| 125 | + |
| 126 | +nvinfer1::ITensor& convertToQTensor(TensorOrWeights& qInput, OnnxAttrs const& attrs, ImporterContext* ctx) |
| 127 | +{ |
| 128 | + return scaleQKTensor(reshapeQKVTensor(qInput, attrs, ctx, true), attrs, ctx); |
| 129 | +} |
| 130 | + |
| 131 | +nvinfer1::ITensor& convertToKTensor(TensorOrWeights& kInput, OnnxAttrs const& attrs, ImporterContext* ctx) |
| 132 | +{ |
| 133 | + return scaleQKTensor(reshapeQKVTensor(kInput, attrs, ctx, false), attrs, ctx); |
| 134 | +} |
| 135 | + |
| 136 | +nvinfer1::ITensor& convertToVTensor(TensorOrWeights& vInput, OnnxAttrs const& attrs, ImporterContext* ctx) |
| 137 | +{ |
| 138 | + return reshapeQKVTensor(vInput, attrs, ctx, false); |
| 139 | +} |
| 140 | + |
| 141 | +nvinfer1::ITensor& convertToMaskTensor(TensorOrWeights& maskInput, ImporterContext* ctx) |
| 142 | +{ |
| 143 | + ONNXTRT_CHECK(maskInput.shape().nbDims <= 4, |
| 144 | + "Attention masks should have rank leq 4. Got mask with rank " << maskInput.shape().nbDims << ".", |
| 145 | + ErrorCode::kINVALID_NODE); |
| 146 | + |
| 147 | + if (maskInput.shape().nbDims == 4) |
| 148 | + { |
| 149 | + // Mask has rank 4. Directly return the mask tensor. |
| 150 | + return convertToTensor(maskInput, ctx); |
| 151 | + } |
| 152 | + else |
| 153 | + { |
| 154 | + // Mask has rank less than 4. Reshape to rank 4 by prepending dimensions. |
| 155 | + int32_t const numDimsToPrepend = 4 - maskInput.shape().nbDims; |
| 156 | + std::vector<int32_t> unsqueezeAxes(numDimsToPrepend); |
| 157 | + std::iota(unsqueezeAxes.begin(), unsqueezeAxes.end(), 0); |
| 158 | + |
| 159 | + return *unsqueezeTensor(ctx, convertToTensor(maskInput, ctx), unsqueezeAxes); |
| 160 | + } |
| 161 | +} |
| 162 | + |
| 163 | +nvinfer1::AttentionNormalizationOp parseNormalizationOp(OnnxAttrs const& attrs) |
| 164 | +{ |
| 165 | + std::string normalizationOp |
| 166 | + = attrs.get<std::string>("TRT_normalization_op", "softmax"); // Normalization op defaults to softmax. |
| 167 | + if (normalizationOp == "softmax") |
| 168 | + { |
| 169 | + return nvinfer1::AttentionNormalizationOp::kSOFTMAX; |
| 170 | + } |
| 171 | + else if (normalizationOp == "none") |
| 172 | + { |
| 173 | + return nvinfer1::AttentionNormalizationOp::kNONE; |
| 174 | + } |
| 175 | + else |
| 176 | + { |
| 177 | + ONNXTRT_CHECK(false, "Unsupported normalization op: " << normalizationOp, ErrorCode::kINVALID_NODE); |
| 178 | + } |
| 179 | +} |
| 180 | + |
| 181 | +} // namespace onnx2trt |
0 commit comments