Skip to content

Commit adc06c8

Browse files
[ONNX][TORCH] Add Onnx->Linalg lowering for RotaryEmbedding Op (#4002)
This commit adds the Onnx->Linalg lowering for Onnx's RotaryEmbedding op (ref: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftrotaryembedding) by registering a customized torch op named `OnnxVariantAtenRotaryEmbeddingOp`. This is done so that the Onnx's RotaryEmbedding op can be lowered to this op and this op can be lowered from Torch->Linalg. The lowering has been adopted from the OnnxRuntime. Files for references: 1.) https://github.com/microsoft/onnxruntime/blob/e1e3f623f61816008e79dddc91a51ffe7f0ff5cf/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc#L47-L93 2.) https://github.com/microsoft/onnxruntime/blob/94c69f55d480cb4a8dcbc161d29ef3acca9392a7/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h --------- Signed-off-by: Vivek Khandelwal <[email protected]> Co-authored-by: zjgarvey <[email protected]>
1 parent 0c41119 commit adc06c8

File tree

13 files changed

+604
-0
lines changed

13 files changed

+604
-0
lines changed

include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ class OnnxCustomOpConversionPattern
472472

473473
// Patterns are split into chunks to speed compile time and reduce some
474474
// contention on the same source files.
475+
void populateComMicrosoftDomain(OnnxCustomOpConversionPattern &patterns);
475476
void populateDefaultDomainAtoF(OnnxCustomOpConversionPattern &patterns);
476477
void populateDefaultDomainGtoP(OnnxCustomOpConversionPattern &patterns);
477478
void populateDefaultDomainQtoZ(OnnxCustomOpConversionPattern &patterns);

include/torch-mlir/Dialect/Torch/IR/TorchOps.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,4 +1410,36 @@ def Torch_BindSymbolicShapeOp : Torch_Op<"bind_symbolic_shape", []> {
14101410
let hasVerifier = 1;
14111411
}
14121412

1413+
// This op is corresponding to the Onnx's RotaryEmbedding operator.
1414+
// Ref: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftrotaryembedding
1415+
def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [
1416+
AllowsTypeRefinement,
1417+
HasValueSemantics,
1418+
ReadOnly
1419+
]> {
1420+
let summary = "`rotary_embedding op : (Tensor, Tensor, Tensor, Tensor, int, int, int, int, float) -> (Tensor)`";
1421+
let description = [{
1422+
The `torch.onnx.rotary_embedding` operation is an op which is used
1423+
specifically for supporting the Onnx's Rotary Embedding op. The
1424+
reason for this is that the Onnx ops can't be directly lowered to
1425+
Linalg and we have to map them to a legal Torch Dialect op, hence
1426+
this op is used for that purpose.
1427+
}];
1428+
let arguments = (ins
1429+
AnyTorchTensorType:$input,
1430+
AnyTorchTensorType:$position_ids,
1431+
AnyTorchTensorType:$cos_cache,
1432+
AnyTorchTensorType:$sin_cache,
1433+
Torch_IntType:$interleaved,
1434+
Torch_IntType:$is_packed_batching,
1435+
Torch_IntType:$num_heads,
1436+
Torch_IntType:$rotary_embedding_dim,
1437+
Torch_FloatType:$scale
1438+
);
1439+
let results = (outs
1440+
AnyTorchTensorType:$result
1441+
);
1442+
let hasCustomAssemblyFormat = 1;
1443+
}
1444+
14131445
#endif // TORCH_OPS

lib/Conversion/TorchOnnxToTorch/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch
2+
ComMicrosoftDomain.cpp
23
DefaultDomainAtoF.cpp
34
DefaultDomainGtoP.cpp
45
DefaultDomainQtoZ.cpp
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//===------------------------------------------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Also available under a BSD-style license. See LICENSE.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
11+
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
12+
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
13+
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
14+
#include <numeric>
15+
16+
using namespace mlir;
17+
using namespace mlir::torch;
18+
using namespace mlir::torch::onnx_c;
19+
20+
void mlir::torch::onnx_c::populateComMicrosoftDomain(
21+
OnnxCustomOpConversionPattern &patterns) {
22+
patterns.onOp(
23+
"RotaryEmbedding", 1,
24+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
25+
Location loc = binder.getLoc();
26+
int64_t interleaved, isPackedBatching, numHeads, rotaryEmbeddingDim;
27+
float scale;
28+
Value input, positionIds, cosCache, sinCache;
29+
if (binder.tensorOperandAtIndex(input, 0) ||
30+
binder.tensorOperandAtIndex(positionIds, 1) ||
31+
binder.tensorOperandAtIndex(cosCache, 2) ||
32+
binder.tensorOperandAtIndex(sinCache, 3) ||
33+
binder.s64IntegerAttr(interleaved, "interleaved", 0) ||
34+
binder.s64IntegerAttr(isPackedBatching, "is_packed_batching", 0) ||
35+
binder.s64IntegerAttr(numHeads, "num_heads", 0) ||
36+
binder.s64IntegerAttr(rotaryEmbeddingDim, "rotary_embedding_dim",
37+
0) ||
38+
binder.f32FloatAttr(scale, "scale", 1.0)) {
39+
return rewriter.notifyMatchFailure(binder.op,
40+
"Failed to get required inputs");
41+
}
42+
43+
Torch::ValueTensorType resultType;
44+
if (binder.tensorResultType(resultType)) {
45+
return rewriter.notifyMatchFailure(binder.op,
46+
"result type bind failure");
47+
}
48+
49+
Value cstInterleaved = rewriter.create<Torch::ConstantIntOp>(
50+
loc, rewriter.getI64IntegerAttr(interleaved));
51+
Value cstIsPackedBatching = rewriter.create<Torch::ConstantIntOp>(
52+
loc, rewriter.getI64IntegerAttr(isPackedBatching));
53+
Value cstNumHeads = rewriter.create<Torch::ConstantIntOp>(
54+
loc, rewriter.getI64IntegerAttr(numHeads));
55+
Value cstRotaryEmbeddingDim = rewriter.create<Torch::ConstantIntOp>(
56+
loc, rewriter.getI64IntegerAttr(rotaryEmbeddingDim));
57+
Value cstScale = rewriter.create<Torch::ConstantFloatOp>(
58+
loc, rewriter.getF64FloatAttr(scale));
59+
60+
rewriter.replaceOpWithNewOp<Torch::OnnxVariantRotaryEmbeddingOp>(
61+
binder.op, resultType, input, positionIds, cosCache, sinCache,
62+
cstInterleaved, cstIsPackedBatching, cstNumHeads,
63+
cstRotaryEmbeddingDim, cstScale);
64+
return success();
65+
});
66+
}

lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class ConvertTorchOnnxToTorch
5656
std::make_unique<OnnxCustomOpConversionPattern>(
5757
context, "onnx.",
5858
/*domainVersion=*/defaultOpsetVersion);
59+
populateComMicrosoftDomain(*defaultDomainPatterns);
5960
populateDefaultDomainAtoF(*defaultDomainPatterns);
6061
populateDefaultDomainGtoP(*defaultDomainPatterns);
6162
populateDefaultDomainQtoZ(*defaultDomainPatterns);

0 commit comments

Comments
 (0)