Commit 781c3e3: address review comments

1 parent 24168c6 commit 781c3e3
4 files changed (+91, -57 lines)

src/Dialect/ONNX/ONNXOps/NN/Attention.cpp

Lines changed: 44 additions & 52 deletions

@@ -20,7 +20,10 @@ template <>
 LogicalResult ONNXAttentionOpShapeHelper::computeShape() {
   auto attentionOp = cast<ONNXAttentionOp>(op);

-  int64_t rank = createIE->getShapedTypeRank(attentionOp.getQ());
+  const int64_t rank = createIE->getShapedTypeRank(attentionOp.getQ());
+  if (rank != 3 && rank != 4)
+    return failure();
+
   DimsExpr qShape;
   createIE->getShapeAsDims(attentionOp.getQ(), qShape);
   DimsExpr kShape;
@@ -31,29 +34,30 @@ LogicalResult ONNXAttentionOpShapeHelper::computeShape() {
   auto qNumHeads = attentionOp.getQNumHeads();
   auto kvNumHeads = attentionOp.getKvNumHeads();

-  if (rank == 4) {
-    DimsExpr outputDims = qShape;
-    outputDims[3] = vShape[3];
-    setOutputDims(outputDims, 0);
-  } else if (rank == 3) {
-    assert(qNumHeads && kvNumHeads &&
-        "*_num_heads attributes must be present with 3D inputs");
-    DimsExpr outputDims = qShape;
-    outputDims[2] = LitIE(*qNumHeads * (vShape[2].getLiteral() / *kvNumHeads));
-    setOutputDims(outputDims, 0);
-  } else {
-    return failure();
-  }
+  auto normalizeInputTo4D = [](DimsExpr inputShape,
+                                std::optional<int64_t> numHeads) -> DimsExpr {
+    DimsExpr shape4D = inputShape;
+    if (inputShape.size() == 4)
+      return shape4D;

-  // Need past_key/value inputs to infer shapes for present_key/value outputs
-  if (attentionOp->getNumOperands() < 6)
-    return success();
+    assert(numHeads && "*_num_heads attributes must be present with 3D inputs");
+    shape4D.insert(shape4D.begin() + 1, LitIE(*numHeads));
+    shape4D[3] = shape4D[3].floorDiv(shape4D[1]);

-  if (isNoneValue(attentionOp.getPastKey()) ||
-      isNoneValue(attentionOp.getPastValue()) ||
-      isNoneValue(attentionOp.getPresentKey()) ||
-      isNoneValue(attentionOp.getPresentValue()))
-    return success();
+    return shape4D;
+  };
+
+  DimsExpr qShape4D = normalizeInputTo4D(qShape, qNumHeads);
+  DimsExpr kShape4D = normalizeInputTo4D(kShape, kvNumHeads);
+  DimsExpr vShape4D = normalizeInputTo4D(vShape, kvNumHeads);
+
+  DimsExpr outputDims = qShape;
+  if (rank == 4) {
+    outputDims[3] = vShape4D[3];
+  } else /*if (rank == 3)*/ {
+    outputDims[2] = qShape4D[1] * vShape4D[3];
+  }
+  setOutputDims(outputDims, 0);

   if (!hasShapeAndRank(attentionOp.getPastKey()) ||
       !hasShapeAndRank(attentionOp.getPastValue()))
@@ -67,21 +71,19 @@ LogicalResult ONNXAttentionOpShapeHelper::computeShape() {
   if (pastKShape.size() != 4 || pastVShape.size() != 4)
     return failure();

-  auto totalSeqLen = pastKShape[2] + kShape[2];
+  auto totalSeqLen = pastKShape[2] + kShape4D[2];

-  DimsExpr presentKeyDims = kShape;
+  DimsExpr presentKeyDims = kShape4D;
   presentKeyDims[2] = totalSeqLen;
   setOutputDims(presentKeyDims, 1);

-  DimsExpr presentValueDims = vShape;
+  DimsExpr presentValueDims = vShape4D;
   presentValueDims[2] = totalSeqLen;
   setOutputDims(presentValueDims, 2);

-  if (attentionOp.getQkMatmulOutputMode()) {
-    DimsExpr qkOutputDims = qShape;
-    qkOutputDims[3] = totalSeqLen;
-    setOutputDims(presentValueDims, 3);
-  }
+  DimsExpr qkOutputDims = qShape4D;
+  qkOutputDims[3] = totalSeqLen;
+  setOutputDims(presentValueDims, 3);

   return success();
 }
@@ -93,25 +95,16 @@ LogicalResult ONNXAttentionOpShapeHelper::computeShape() {
 //===----------------------------------------------------------------------===//

 LogicalResult ONNXAttentionOp::verify() {
-  const int64_t numIn = this->getNumOperands();
-  const int64_t numOut = this->getNumResults();
-
   // If presentK and presentV are outputs, then we must pass pastK and pastV as
   // inputs
-  if (numOut >= 3) {
-    Value presentK = this->getResult(1);
-    Value presentV = this->getResult(2);
-    if (!isNoneValue(presentK) || !isNoneValue(presentV)) {
-      if (numIn < 6)
-        return emitOpError("inputs 'pastK' and 'pastV' are needed for outputs "
-                           "'presentK' and 'presentV'");
-
-      Value pastK = this->getOperand(4);
-      Value pastV = this->getOperand(5);
-      if (isNoneValue(pastK) || isNoneValue(pastV))
-        return emitOpError("inputs 'pastK' and 'pastV' are needed for outputs "
-                           "'presentK' and 'presentV'");
-    }
+  Value presentK = this->getResult(1);
+  Value presentV = this->getResult(2);
+  if (!isNoneValue(presentK) || !isNoneValue(presentV)) {
+    Value pastK = this->getOperand(4);
+    Value pastV = this->getOperand(5);
+    if (isNoneValue(pastK) || isNoneValue(pastV))
+      return emitOpError("inputs 'pastK' and 'pastV' are needed for outputs "
+                         "'presentK' and 'presentV'");
   }

   ONNXAttentionOpAdaptor adaptor(*this);
@@ -120,7 +113,7 @@ LogicalResult ONNXAttentionOp::verify() {
   if (!hasShapeAndRank(q))
     return success(); // Won't be able to do any more checking at this stage.

-  auto qType = mlir::cast<ShapedType>(q.getType());
+  auto qType = cast<ShapedType>(q.getType());
   int64_t qRank = qType.getShape().size();
   if (qRank != 3 && qRank != 4)
     return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError(
@@ -137,13 +130,13 @@ LogicalResult ONNXAttentionOp::verify() {
   if (!hasShapeAndRank(k) || !hasShapeAndRank(v))
     return success(); // Won't be able to do any more checking at this stage.

-  auto kType = mlir::cast<ShapedType>(k.getType());
+  auto kType = cast<ShapedType>(k.getType());
   int64_t kRank = kType.getShape().size();
   if (kRank != 3 && kRank != 4)
     return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError(
         *this->getOperation(), k, kRank, "3 or 4");

-  auto vType = mlir::cast<ShapedType>(v.getType());
+  auto vType = cast<ShapedType>(v.getType());
   int64_t vRank = vType.getShape().size();
   if (vRank != 3 && vRank != 4)
     return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError(
@@ -195,10 +188,9 @@ LogicalResult ONNXAttentionOp::inferShapes(
     if (!hasShapeAndRank(this->getOperand(i)))
       return success();

-  Type elementType = mlir::cast<ShapedType>(getQ().getType()).getElementType();
+  Type elementType = getElementTypeOrSelf(getQ().getType());
   ONNXAttentionOpShapeHelper shapeHelper(getOperation(), {});
   return shapeHelper.computeShapeAndUpdateType(elementType);
-  return success();
 }

 //===----------------------------------------------------------------------===//

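The 3D path now reuses the 4D logic by conceptually reshaping [batch, seq, hidden] into [batch, num_heads, seq, hidden / num_heads], which also lets the present_key/value and QK-output shapes be computed uniformly in the later hunks. Below is a minimal standalone sketch of that arithmetic with plain integers (my own illustration, not part of the commit; the real normalizeInputTo4D operates on symbolic DimsExpr values), using the shapes from the new @test_attention_3d_inputs_4d_present_kv test further down:

// Illustrative sketch only: mirrors normalizeInputTo4D() with plain integers;
// the onnx-mlir helper works on symbolic DimsExpr/IndexExpr values instead.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

static Shape normalizeInputTo4D(Shape shape, std::optional<int64_t> numHeads) {
  if (shape.size() == 4)
    return shape;
  assert(numHeads && "*_num_heads attributes must be present with 3D inputs");
  // [batch, seq, hidden] -> [batch, numHeads, seq, hidden / numHeads]
  shape.insert(shape.begin() + 1, *numHeads);
  shape[3] = shape[3] / shape[1];
  return shape;
}

int main() {
  Shape q4 = normalizeInputTo4D({1, 128, 3072}, 32); // {1, 32, 128, 96}
  Shape k4 = normalizeInputTo4D({1, 128, 1536}, 16); // {1, 16, 128, 96}
  Shape v4 = normalizeInputTo4D({1, 128, 768}, 16);  // {1, 16, 128, 48}

  // Rank-3 output stays 3D: hidden = q_num_heads * v_head_size.
  std::cout << q4[1] * v4[3] << "\n"; // 32 * 48 = 1536 -> tensor<1x128x1536xf32>

  // present_key/value sequence length = past_key seq (256) + current seq.
  std::cout << 256 + k4[2] << "\n"; // 384 -> tensor<1x16x384x96xf32>
  return 0;
}
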
src/Dialect/ONNX/ONNXOps/NN/RotaryEmbedding.cpp

Lines changed: 11 additions & 5 deletions

@@ -35,7 +35,7 @@ LogicalResult ONNXRotaryEmbeddingOp::verify() {
     return success(); // Won't be able to do any checking at this stage.

   auto inputType = mlir::cast<ShapedType>(input.getType());
-  int64_t inputRank = inputType.getRank();
+  const int64_t inputRank = inputType.getRank();

   if (inputRank != 3 && inputRank != 4)
     return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError(
@@ -46,13 +46,19 @@ LogicalResult ONNXRotaryEmbeddingOp::verify() {
     return emitOpError(
         "attribute 'num_heads' must be provided when input is a 3D tensor.");

-  // Check hidden_size divisible by num_heads
   if (inputType.hasStaticShape()) {
     auto inputShape = inputType.getShape();
-    if (inputRank == 3 && numHeads && inputShape[2] % *numHeads != 0)
+    // Check head_size is even
+    if (inputRank == 4 && inputShape[3] % 2 != 0)
+      return onnx_mlir::Diagnostic::emitDimensionHasUnexpectedValueError(
+          *this->getOperation(), input, 3, inputShape[3], "even");
+
+    // Check hidden_size divisible by num_heads and resulting head_size is
+    // even (i.e. hidden_size % (num_heads * 2) == 0)
+    if (inputRank == 3 && numHeads && inputShape[2] % (*numHeads * 2) != 0)
       return onnx_mlir::Diagnostic::emitDimensionHasUnexpectedValueError(
           *this->getOperation(), input, 2, inputShape[2],
-          "divisible by " + std::to_string(*numHeads));
+          "divisible by " + std::to_string(*numHeads) + " * 2");
   }

   Value cosCache = adaptor.getCosCache();
@@ -103,7 +109,7 @@ LogicalResult ONNXRotaryEmbeddingOp::inferShapes(
   if (!hasShapeAndRank(getOperation()->getOperand(0)))
     return success();

-  Type elementType = mlir::cast<ShapedType>(getX().getType()).getElementType();
+  Type elementType = getElementTypeOrSelf(getX().getType());
   ONNXRotaryEmbeddingOpShapeHelper shapeHelper(getOperation(), {});
   return shapeHelper.computeShapeAndUpdateType(elementType);
 }

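As a quick worked check of the strengthened RotaryEmbedding verifier: with num_heads = 32, a 3D hidden size of 3040 implies a head size of 3040 / 32 = 95, and 3040 % (32 * 2) = 32 != 0, so the op is rejected, while 3072 (head size 96) still passes; for 4D inputs the last dimension is checked directly, so the odd value 95 fails the evenness test. The two new invalid.mlir cases below exercise exactly these values.
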
test/mlir/onnx/invalid.mlir

Lines changed: 18 additions & 0 deletions

@@ -1007,6 +1007,24 @@ func.func @test_rotary_embedding_bad_dtype(%data: tensor<1x128x3072xi64>, %cos_c

 // -----

+func.func @test_rotary_embedding_4d_odd_head_size(%data: tensor<1x32x128x95xf32>, %cos_cache: tensor<4096x48xf32>, %sin_cache: tensor<4096x48xf32>) -> tensor<*xf32> {
+  %pos_ids = "onnx.NoValue"() {value} : () -> none
+  // expected-error @+1 {{onnx.RotaryEmbedding: operand '<block argument> of type 'tensor<1x32x128x95xf32>' at index: 0' has dimension at index 3 with value 95, value should be even}}
+  %0 = "onnx.RotaryEmbedding"(%data, %cos_cache, %sin_cache, %pos_ids) {num_heads = 32: si64} : (tensor<1x32x128x95xf32>, tensor<4096x48xf32>, tensor<4096x48xf32>, none) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_rotary_embedding_3d_odd_head_size(%data: tensor<1x128x3040xf32>, %cos_cache: tensor<4096x48xf32>, %sin_cache: tensor<4096x48xf32>) -> tensor<*xf32> {
+  %pos_ids = "onnx.NoValue"() {value} : () -> none
+  // expected-error @+1 {{onnx.RotaryEmbedding: operand '<block argument> of type 'tensor<1x128x3040xf32>' at index: 0' has dimension at index 2 with value 3040, value should be divisible by 32 * 2}}
+  %0 = "onnx.RotaryEmbedding"(%data, %cos_cache, %sin_cache, %pos_ids) {num_heads = 32: si64} : (tensor<1x128x3040xf32>, tensor<4096x48xf32>, tensor<4096x48xf32>, none) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
 func.func @test_rotary_embedding_bad_embedding_dim(%data: tensor<1x32x128x96xf32>, %cos_cache: tensor<4096x48xf32>, %sin_cache: tensor<4096x48xf32>) -> tensor<*xf32> {
   %pos_ids = "onnx.NoValue"() {value} : () -> none
   // expected-error @+1 {{onnx.RotaryEmbedding: operand '<block argument> of type 'tensor<4096x48xf32>' at index: 1' has dimension at index 1 with value 48, value should be 50}}

test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 18 additions & 0 deletions

@@ -4536,3 +4536,21 @@ func.func @test_attention_4d_qk_output(%q: tensor<1x32x128x96xf32>, %k: tensor<1
 // CHECK-LABEL: func.func @test_attention_4d_qk_output
 // CHECK: "onnx.Attention"
 // CHECK-SAME: (tensor<1x32x128x96xf32>, tensor<1x16x128x96xf32>, tensor<1x16x128x48xf32>, none, tensor<1x16x256x96xf32>, tensor<1x16x256x48xf32>) -> (tensor<1x32x128x48xf32>, tensor<1x16x384x96xf32>, tensor<1x16x384x48xf32>, tensor<1x16x384x48xf32>)
+
+func.func @test_attention_3d_inputs_4d_present_kv(%q: tensor<1x128x3072xf32>, %k: tensor<1x128x1536xf32>, %v: tensor<1x128x768xf32>, %past_k: tensor<1x16x256x96xf32>, %past_v: tensor<1x16x256x48xf32>) -> tensor<*xf32> {
+  %none = "onnx.NoValue"() {value} : () -> none
+  %out, %present_k, %present_v, %qk_out = "onnx.Attention"(%q, %k, %v, %none, %past_k, %past_v) {q_num_heads = 32: si64, kv_num_heads = 16: si64} : (tensor<1x128x3072xf32>, tensor<1x128x1536xf32>, tensor<1x128x768xf32>, none, tensor<1x16x256x96xf32>, tensor<1x16x256x48xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, none)
+  return %out : tensor<*xf32>
+}
+// CHECK-LABEL: func.func @test_attention_3d_inputs_4d_present_kv
+// CHECK: "onnx.Attention"
+// CHECK-SAME: (tensor<1x128x3072xf32>, tensor<1x128x1536xf32>, tensor<1x128x768xf32>, none, tensor<1x16x256x96xf32>, tensor<1x16x256x48xf32>) -> (tensor<1x128x1536xf32>, tensor<1x16x384x96xf32>, tensor<1x16x384x48xf32>, none)
+
+func.func @test_attention_3d_q_4d_kv(%q: tensor<1x128x3072xf32>, %k: tensor<1x16x128x96xf32>, %v: tensor<1x16x128x48xf32>) -> tensor<*xf32> {
+  %none = "onnx.NoValue"() {value} : () -> none
+  %out, %present_k, %present_v, %qk_out = "onnx.Attention"(%q, %k, %v, %none, %none, %none) {q_num_heads = 32: si64} : (tensor<1x128x3072xf32>, tensor<1x16x128x96xf32>, tensor<1x16x128x48xf32>, none, none, none) -> (tensor<*xf32>, none, none, none)
+  return %out : tensor<*xf32>
+}
+// CHECK-LABEL: func.func @test_attention_3d_q_4d_kv
+// CHECK: "onnx.Attention"
+// CHECK-SAME: (tensor<1x128x3072xf32>, tensor<1x16x128x96xf32>, tensor<1x16x128x48xf32>, none, none, none) -> (tensor<1x128x1536xf32>
