@@ -20,7 +20,10 @@ template <>
 LogicalResult ONNXAttentionOpShapeHelper::computeShape() {
   auto attentionOp = cast<ONNXAttentionOp>(op);
 
-  int64_t rank = createIE->getShapedTypeRank(attentionOp.getQ());
+  const int64_t rank = createIE->getShapedTypeRank(attentionOp.getQ());
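+  // Per the ONNX Attention spec, Q/K/V are either 3D (batch, seq, hidden)
+  // or 4D (batch, num_heads, seq, head_size); reject any other rank early.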
+  if (rank != 3 && rank != 4)
+    return failure();
+
   DimsExpr qShape;
   createIE->getShapeAsDims(attentionOp.getQ(), qShape);
   DimsExpr kShape;
@@ -31,29 +34,30 @@ LogicalResult ONNXAttentionOpShapeHelper::computeShape() {
   auto qNumHeads = attentionOp.getQNumHeads();
   auto kvNumHeads = attentionOp.getKvNumHeads();
 
-  if (rank == 4) {
-    DimsExpr outputDims = qShape;
-    outputDims[3] = vShape[3];
-    setOutputDims(outputDims, 0);
-  } else if (rank == 3) {
-    assert(qNumHeads && kvNumHeads &&
-           "*_num_heads attributes must be present with 3D inputs");
-    DimsExpr outputDims = qShape;
-    outputDims[2] = LitIE(*qNumHeads * (vShape[2].getLiteral() / *kvNumHeads));
-    setOutputDims(outputDims, 0);
-  } else {
-    return failure();
-  }
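+  // Normalize a 3D (batch, seq, hidden) shape to its 4D
+  // (batch, num_heads, seq, head_size) equivalent so the logic below can
+  // assume rank 4; 4D shapes pass through unchanged.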
+  auto normalizeInputTo4D = [](DimsExpr inputShape,
+                               std::optional<int64_t> numHeads) -> DimsExpr {
+    DimsExpr shape4D = inputShape;
+    if (inputShape.size() == 4)
+      return shape4D;
 
-  // Need past_key/value inputs to infer shapes for present_key/value outputs
-  if (attentionOp->getNumOperands() < 6)
-    return success();
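+    // 3D case: splice the head count in after the batch dim and divide the
+    // hidden size by it to get the per-head size.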
+    assert(numHeads && "*_num_heads attributes must be present with 3D inputs");
+    shape4D.insert(shape4D.begin() + 1, LitIE(*numHeads));
+    shape4D[3] = shape4D[3].floorDiv(shape4D[1]);
 
-  if (isNoneValue(attentionOp.getPastKey()) ||
-      isNoneValue(attentionOp.getPastValue()) ||
-      isNoneValue(attentionOp.getPresentKey()) ||
-      isNoneValue(attentionOp.getPresentValue()))
-    return success();
+    return shape4D;
+  };
+
+  DimsExpr qShape4D = normalizeInputTo4D(qShape, qNumHeads);
+  DimsExpr kShape4D = normalizeInputTo4D(kShape, kvNumHeads);
+  DimsExpr vShape4D = normalizeInputTo4D(vShape, kvNumHeads);
+
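+  // Output 0 keeps Q's shape except for its last dim, which comes from V's
+  // per-head size (times the number of heads in the 3D case).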
+  DimsExpr outputDims = qShape;
+  if (rank == 4) {
+    outputDims[3] = vShape4D[3];
+  } else /* rank == 3 */ {
+    outputDims[2] = qShape4D[1] * vShape4D[3];
+  }
+  setOutputDims(outputDims, 0);
 
   if (!hasShapeAndRank(attentionOp.getPastKey()) ||
       !hasShapeAndRank(attentionOp.getPastValue()))
@@ -67,21 +71,19 @@ LogicalResult ONNXAttentionOpShapeHelper::computeShape() {
   if (pastKShape.size() != 4 || pastVShape.size() != 4)
     return failure();
 
-  auto totalSeqLen = pastKShape[2] + kShape[2];
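+  // The cached sequence grows: total length = past length + current K length.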
+  auto totalSeqLen = pastKShape[2] + kShape4D[2];
 
-  DimsExpr presentKeyDims = kShape;
+  DimsExpr presentKeyDims = kShape4D;
   presentKeyDims[2] = totalSeqLen;
   setOutputDims(presentKeyDims, 1);
 
-  DimsExpr presentValueDims = vShape;
+  DimsExpr presentValueDims = vShape4D;
   presentValueDims[2] = totalSeqLen;
   setOutputDims(presentValueDims, 2);
 
-  if (attentionOp.getQkMatmulOutputMode()) {
-    DimsExpr qkOutputDims = qShape;
-    qkOutputDims[3] = totalSeqLen;
-    setOutputDims(presentValueDims, 3);
-  }
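+  // qk_matmul_output is (batch, num_heads, q_seq, total_seq).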
+  DimsExpr qkOutputDims = qShape4D;
+  qkOutputDims[3] = totalSeqLen;
+  setOutputDims(qkOutputDims, 3);
 
   return success();
 }
@@ -93,25 +95,16 @@ LogicalResult ONNXAttentionOpShapeHelper::computeShape() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ONNXAttentionOp::verify() {
-  const int64_t numIn = this->getNumOperands();
-  const int64_t numOut = this->getNumResults();
-
   // If presentK and presentV are outputs, then we must pass pastK and pastV as
   // inputs
-  if (numOut >= 3) {
-    Value presentK = this->getResult(1);
-    Value presentV = this->getResult(2);
-    if (!isNoneValue(presentK) || !isNoneValue(presentV)) {
-      if (numIn < 6)
-        return emitOpError("inputs 'pastK' and 'pastV' are needed for outputs "
-                           "'presentK' and 'presentV'");
-
-      Value pastK = this->getOperand(4);
-      Value pastV = this->getOperand(5);
-      if (isNoneValue(pastK) || isNoneValue(pastV))
-        return emitOpError("inputs 'pastK' and 'pastV' are needed for outputs "
-                           "'presentK' and 'presentV'");
-    }
+  Value presentK = this->getResult(1);
+  Value presentV = this->getResult(2);
+  if (!isNoneValue(presentK) || !isNoneValue(presentV)) {
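+    // Operands 4 and 5 are the optional past_key/past_value inputs.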
+    Value pastK = this->getOperand(4);
+    Value pastV = this->getOperand(5);
+    if (isNoneValue(pastK) || isNoneValue(pastV))
+      return emitOpError("inputs 'pastK' and 'pastV' are needed for outputs "
+                         "'presentK' and 'presentV'");
   }
 
   ONNXAttentionOpAdaptor adaptor(*this);
@@ -120,7 +113,7 @@ LogicalResult ONNXAttentionOp::verify() {
   if (!hasShapeAndRank(q))
     return success(); // Won't be able to do any more checking at this stage.
 
-  auto qType = mlir::cast<ShapedType>(q.getType());
+  auto qType = cast<ShapedType>(q.getType());
   int64_t qRank = qType.getShape().size();
   if (qRank != 3 && qRank != 4)
     return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError(
@@ -137,13 +130,13 @@ LogicalResult ONNXAttentionOp::verify() {
   if (!hasShapeAndRank(k) || !hasShapeAndRank(v))
     return success(); // Won't be able to do any more checking at this stage.
 
-  auto kType = mlir::cast<ShapedType>(k.getType());
+  auto kType = cast<ShapedType>(k.getType());
   int64_t kRank = kType.getShape().size();
   if (kRank != 3 && kRank != 4)
     return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError(
         *this->getOperation(), k, kRank, "3 or 4");
 
-  auto vType = mlir::cast<ShapedType>(v.getType());
+  auto vType = cast<ShapedType>(v.getType());
   int64_t vRank = vType.getShape().size();
   if (vRank != 3 && vRank != 4)
     return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError(
@@ -195,10 +188,9 @@ LogicalResult ONNXAttentionOp::inferShapes(
     if (!hasShapeAndRank(this->getOperand(i)))
       return success();
 
-  Type elementType = mlir::cast<ShapedType>(getQ().getType()).getElementType();
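+  // getElementTypeOrSelf yields the element type of the shaped Q type (and
+  // is a no-op on unshaped types).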
+  Type elementType = getElementTypeOrSelf(getQ().getType());
   ONNXAttentionOpShapeHelper shapeHelper(getOperation(), {});
   return shapeHelper.computeShapeAndUpdateType(elementType);
-  return success();
 }
 
 //===----------------------------------------------------------------------===//