
Commit 6685d83

convert com.microsoft to onnx RotaryEmbedding
1 parent 51022ca commit 6685d83

File tree

3 files changed: +103 -1 lines

src/Dialect/ONNX/ONNXOps/NN/RotaryEmbedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ LogicalResult ONNXRotaryEmbeddingOp::verify() {
         *this->getOperation(), cosCache, lastIndex, cosCacheShape[lastIndex],
         std::to_string(rotaryEmbeddingDim / 2));
   lastIndex = sinCacheShape.size() - 1;
-  if (sinCacheShape[lastIndex] == rotaryEmbeddingDim / 2)
+  if (sinCacheShape[lastIndex] != rotaryEmbeddingDim / 2)
     return onnx_mlir::Diagnostic::emitDimensionHasUnexpectedValueError(
         *this->getOperation(), sinCache, lastIndex, sinCacheShape[lastIndex],
         std::to_string(rotaryEmbeddingDim / 2));
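As a hedged illustration of what the corrected check enforces (the function below is a sketch, not part of the commit; its shapes and attributes mirror the new lit test further down): with rotary_embedding_dim = 16, the last dimension of cos_cache and sin_cache must equal 16 / 2 = 8. Before this change the comparison was inverted, so exactly this well-formed sin_cache shape was diagnosed, while a mismatching shape such as tensor<4096x16xf32> slipped through.

// Sketch only: accepted by the corrected verifier because the caches'
// last dimension (8) equals rotary_embedding_dim / 2.
func.func @rotary_embedding_verify_sketch(%x: tensor<1x32x128x96xf32>, %pos: tensor<1x128xi64>, %cos: tensor<4096x8xf32>, %sin: tensor<4096x8xf32>) -> tensor<1x32x128x96xf32> {
  %0 = "onnx.RotaryEmbedding"(%x, %cos, %sin, %pos) {interleaved = 1 : si64, rotary_embedding_dim = 16 : si64} : (tensor<1x32x128x96xf32>, tensor<4096x8xf32>, tensor<4096x8xf32>, tensor<1x128xi64>) -> tensor<1x32x128x96xf32>
  return %0 : tensor<1x32x128x96xf32>
}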

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 57 additions & 0 deletions
@@ -3458,6 +3458,62 @@ struct MicrosoftGroupQueryAttention : public CustomOpToOnnxOps {
   };
 };
 
+struct MicrosoftRotaryEmbedding : public CustomOpToOnnxOps {
+  MicrosoftRotaryEmbedding(MLIRContext *ctx, PatternBenefit b = 1)
+      : CustomOpToOnnxOps(ctx, MicrosoftDomainName, "RotaryEmbedding", b) {}
+
+  LogicalResult matchAndRewriteImpl(
+      ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
+
+    const Location loc = customOp.getLoc();
+    const int64_t numIn = customOp.getNumOperands();
+    assert((numIn == 4) && "expects 4 inputs");
+    const int64_t numOut = customOp.getNumResults();
+    assert((numOut == 1) && "expects 1 output");
+
+    Value input = customOp.getOperand(0);
+    Value position_ids = customOp.getOperand(1);
+    Value cos_cache = customOp.getOperand(2);
+    Value sin_cache = customOp.getOperand(3);
+
+    if (customOp->hasAttrOfType<IntegerAttr>("is_packed_batching") &&
+        customOp->getAttrOfType<IntegerAttr>("is_packed_batching").getSInt() !=
+            0)
+      return rewriter.notifyMatchFailure(
+          customOp, "attribute 'is_packed_batching' not supported by "
+                    "onnx.RotaryEmbedding");
+    if (customOp->hasAttrOfType<FloatAttr>("scale") &&
+        customOp->getAttrOfType<FloatAttr>("scale").getValueAsDouble() != 1.0f)
+      return rewriter.notifyMatchFailure(
+          customOp, "attribute 'scale' not supported by onnx.RotaryEmbedding");
+
+    auto rotaryEmbedding =
+        rewriter.create<ONNXRotaryEmbeddingOp>(loc, customOp->getResultTypes(),
+            ValueRange{input, cos_cache, sin_cache, position_ids});
+
+    if (customOp->hasAttrOfType<IntegerAttr>("num_heads"))
+      rotaryEmbedding->setAttr(
+          "num_heads", customOp->getAttrOfType<IntegerAttr>("num_heads"));
+
+    if (customOp->hasAttrOfType<IntegerAttr>("interleaved"))
+      rotaryEmbedding->setAttr(
+          "interleaved", customOp->getAttrOfType<IntegerAttr>("interleaved"));
+
+    if (customOp->hasAttrOfType<IntegerAttr>("rotary_embedding_dim"))
+      rotaryEmbedding->setAttr("rotary_embedding_dim",
+          customOp->getAttrOfType<IntegerAttr>("rotary_embedding_dim"));
+
+    if (failed(verifyOpsErasingOnError({rotaryEmbedding}, rewriter))) {
+      return rewriter.notifyMatchFailure(
+          customOp, "Decomposition failed verification");
+    }
+
+    rewriter.replaceOp(customOp, rotaryEmbedding);
+
+    return success();
+  };
+};
+
 template <typename OpToCreate>
 struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpToOnnxOps {
   CustomOpMicrosoftToSingleOnnxOp(MLIRContext *context,
@@ -3946,6 +4002,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
   patterns.insert<SimplifiedLayerNorm>(context);
   patterns.insert<MicrosoftSkipSimplifiedLayerNorm>(context);
   patterns.insert<MicrosoftGroupQueryAttention>(context);
+  patterns.insert<MicrosoftRotaryEmbedding>(context);
   patterns.insert<DecomposeSlicePadPattern>(context);
   patterns.insert<DecomposeScatterNDPattern>(context);
   patterns.insert<SoftmaxCrossEntropyPattern>(context);
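The pattern declines the rewrite when the custom op carries attributes that onnx.RotaryEmbedding cannot express (is_packed_batching != 0 or scale != 1.0). A hedged sketch of such a case (hypothetical, not among the added tests): a com.microsoft RotaryEmbedding with a non-unit scale is expected to be left as onnx.Custom, since the match fails with notifyMatchFailure.

// Sketch only: 'scale = 0.5' has no counterpart on onnx.RotaryEmbedding, so the
// pattern is expected to refuse the rewrite and leave this onnx.Custom in place.
func.func @rotary_embedding_scale_sketch(%data: tensor<1x128x3072xf32>, %pos_ids: tensor<1x128xi64>, %cos_cache: tensor<4096x48xf32>, %sin_cache: tensor<4096x48xf32>) -> tensor<1x128x3072xf32> {
  %0 = "onnx.Custom"(%data, %pos_ids, %cos_cache, %sin_cache) {
      domain_name = "com.microsoft",
      function_name = "RotaryEmbedding",
      num_heads = 32 : si64,
      scale = 0.5 : f32
  } : (tensor<1x128x3072xf32>, tensor<1x128xi64>, tensor<4096x48xf32>, tensor<4096x48xf32>) -> tensor<1x128x3072xf32>
  return %0 : tensor<1x128x3072xf32>
}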

test/mlir/onnx/onnx_decompose_customop.mlir

Lines changed: 45 additions & 0 deletions
@@ -866,3 +866,48 @@ func.func @gqa_with_scale_softcap_and_qk_output_2(
 // CHECK-SAME:          : (tensor<1x128x3072xf32>, tensor<1x128x1536xf32>, tensor<1x128x1536xf32>, none, tensor<1x16x256x96xf32>, tensor<1x16x256x96xf32>) -> (tensor<1x128x3072xf32>, tensor<1x16x384x96xf32>, tensor<1x16x384x96xf32>, tensor<1x32x128x256xf32>)
 // CHECK:           return %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : tensor<1x128x3072xf32>, tensor<1x16x384x96xf32>, tensor<1x16x384x96xf32>, tensor<1x32x128x256xf32>
 // CHECK:         }
+
+// -----
+
+func.func @rotary_embedding_4d_interleaved_rotdim_16(%data: tensor<1x32x128x96xf32>, %pos_ids: tensor<1x128xi64>, %cos_cache: tensor<4096x8xf32>, %sin_cache: tensor<4096x8xf32>) -> tensor<1x32x128x96xf32> {
+  %0 = "onnx.Custom"(%data, %pos_ids, %cos_cache, %sin_cache) {
+      domain_name = "com.microsoft",
+      function_name = "RotaryEmbedding",
+      interleaved = 1 : si64,
+      rotary_embedding_dim = 16 : si64
+  }: (tensor<1x32x128x96xf32>, tensor<1x128xi64>, tensor<4096x8xf32>, tensor<4096x8xf32>) -> tensor<1x32x128x96xf32>
+  return %0 : tensor<1x32x128x96xf32>
+}
+
+// CHECK-LABEL:   func.func @rotary_embedding_4d_interleaved_rotdim_16(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x32x128x96xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<1x128xi64>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4096x8xf32>,
+// CHECK-SAME:      %[[VAL_3:.*]]: tensor<4096x8xf32>) -> tensor<1x32x128x96xf32> {
+// CHECK:           %[[VAL_4:.*]] = "onnx.RotaryEmbedding"(%[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[VAL_1]])
+// CHECK-SAME:          {interleaved = 1 : si64, rotary_embedding_dim = 16 : si64}
+// CHECK-SAME:          : (tensor<1x32x128x96xf32>, tensor<4096x8xf32>, tensor<4096x8xf32>, tensor<1x128xi64>) -> tensor<1x32x128x96xf32>
+// CHECK:           return %[[VAL_4]] : tensor<1x32x128x96xf32>
+// CHECK:         }
+
+// -----
+
+func.func @test_rotary_embedding_3d(%data: tensor<1x128x3072xf32>, %pos_ids: tensor<1x128xi64>, %cos_cache: tensor<4096x48xf32>, %sin_cache: tensor<4096x48xf32>) -> tensor<1x128x3072xf32> {
+  %0 = "onnx.Custom"(%data, %pos_ids, %cos_cache, %sin_cache) {
+      domain_name = "com.microsoft",
+      function_name = "RotaryEmbedding",
+      num_heads = 32 : si64
+  } : (tensor<1x128x3072xf32>, tensor<1x128xi64>, tensor<4096x48xf32>, tensor<4096x48xf32>) -> tensor<1x128x3072xf32>
+  return %0 : tensor<1x128x3072xf32>
+}
+
+// CHECK-LABEL:   func.func @test_rotary_embedding_3d(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x128x3072xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<1x128xi64>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4096x48xf32>,
+// CHECK-SAME:      %[[VAL_3:.*]]: tensor<4096x48xf32>) -> tensor<1x128x3072xf32> {
+// CHECK:           %[[VAL_4:.*]] = "onnx.RotaryEmbedding"(%[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[VAL_1]])
+// CHECK-SAME:          {interleaved = 0 : si64, num_heads = 32 : si64, rotary_embedding_dim = 0 : si64}
+// CHECK-SAME:          : (tensor<1x128x3072xf32>, tensor<4096x48xf32>, tensor<4096x48xf32>, tensor<1x128xi64>) -> tensor<1x128x3072xf32>
+// CHECK:           return %[[VAL_4]] : tensor<1x128x3072xf32>
+// CHECK:         }
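One more hedged sketch (hypothetical, not among the added tests): because of the is_packed_batching guard in the new pattern, an op like the following is expected to be skipped and remain an onnx.Custom after decomposition.

// Sketch only: is_packed_batching = 1 fails the match in MicrosoftRotaryEmbedding,
// so this op is expected to survive the decompose pass unchanged.
func.func @rotary_embedding_packed_batching_sketch(%data: tensor<1x128x3072xf32>, %pos_ids: tensor<1x128xi64>, %cos_cache: tensor<4096x48xf32>, %sin_cache: tensor<4096x48xf32>) -> tensor<1x128x3072xf32> {
  %0 = "onnx.Custom"(%data, %pos_ids, %cos_cache, %sin_cache) {
      domain_name = "com.microsoft",
      function_name = "RotaryEmbedding",
      is_packed_batching = 1 : si64,
      num_heads = 32 : si64
  } : (tensor<1x128x3072xf32>, tensor<1x128xi64>, tensor<4096x48xf32>, tensor<4096x48xf32>) -> tensor<1x128x3072xf32>
  return %0 : tensor<1x128x3072xf32>
}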
