Commit 45f07d5

Transform SequenceAt to split for special cases (#3018)
* implement
* test case
* format
* fix

Signed-off-by: chentong319 <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
1 parent e801b36 commit 45f07d5

File tree

5 files changed: +181 −2 lines changed

src/Dialect/ONNX/DialectBuilder.cpp

Lines changed: 0 additions & 1 deletion

@@ -444,7 +444,6 @@ TensorType OnnxBuilder::toTensor(Type input) const {
 }
 
 TypeRange OnnxBuilder::toTensors(TypeRange inputs) const {
-  assert(inputs.size() >= 2 && "Expect at least two inputs");
   if (llvm::all_of(inputs, [](Type t) { return (mlir::isa<TensorType>(t)); }))
     return inputs;
   assert(llvm::all_of(inputs, [](Type t) {

src/Dialect/ONNX/ONNXOps/OpHelper.cpp

Lines changed: 2 additions & 1 deletion

@@ -307,7 +307,8 @@ void ArrayAttrIntVals(ArrayAttr a, mlir::SmallVectorImpl<int64_t> &i) {
 
 ElementsAttr getElementAttributeFromONNXValue(Value value) {
   ONNXConstantOp constantOp = getONNXConstantOp(value);
-  if (constantOp)
+  // In case the ConstantOp has not been normalized yet
+  if (constantOp && constantOp.getValueAttr())
     return mlir::dyn_cast<ElementsAttr>(constantOp.getValueAttr());
   return nullptr;
 }
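
The extra null check matters because an ONNXConstantOp may still carry one of the convenience attributes (value_int, value_float, ...) before the normalization patterns move the payload into a dense value attribute. A minimal sketch of such an un-normalized constant, assuming this attribute spelling:

  // getValueAttr() is null here: the scalar lives in value_int until a
  // ConstantOpNormalizationPattern rewrites it into a dense "value" attr.
  %c = "onnx.Constant"() {value_int = 2 : si64} : () -> tensor<i64>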

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 118 additions & 0 deletions

@@ -333,6 +333,120 @@ bool hasStaticSpatialDims(Value v) {
   return llvm::none_of(Ds, ShapedType::isDynamic);
 }
 
+// In the following pattern, a SequenceAt can be replaced with Split:
+//   %seq = onnx.SplitToSequence(%input, %split) {%axis : }
+//   %res = onnx.SequenceAt(%seq, %position)
+// We try to avoid using the sequence-related ops, which are less
+// optimized, or even not implemented, in onnx-mlir.
+// In the targeted use case, %split and %position are constant scalars and
+// the tensors %input and %res have static shapes.
+// This condition greatly reduces the complexity of the code generated to
+// replace SequenceAt with a Split op:
+//   %res = onnx.Split(%input, onnx.expand(%split, %input.shape()[%axis]))
+//          {%axis : } : %position
+// onnx.expand(%split, %input.shape()[%axis]) can be a constant under the
+// assumed condition.
+// Here %position has to be a compile-time constant.
+// For multiple SequenceAt ops on the same SplitToSequence result, the
+// onnx.Split ops generated for the different SequenceAt ops are expected
+// to be merged by later optimization.
+// Alternatively, Slice can be used:
+//   %res = onnx.Slice(%input, %start, %end, %step)
+// The start and end for the Slice will be onnx.Constant:
+//   start: %position*%split for %axis, 0 for the other dimensions
+//   end: (%position+1)*%split for %axis, the upper bound for the other
+//        dimensions
+//   step: 1 for all dimensions
+// The Split approach may have better performance than the alternative
+// Slice approach, because each SequenceAt would become a separate Slice,
+// while the identical Split ops can be merged into one.
+
+bool canSequenceAtBeReplaced(Value sequenceAtResult) {
+  if (!hasStaticShape(sequenceAtResult.getType()))
+    return false;
+
+  ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp<ONNXSequenceAtOp>();
+
+  Value inputSequence = op.getInputSequence();
+  Value position = op.getPosition();
+
+  if (!isDenseONNXConstant(position))
+    return false;
+
+  // The input sequence should be defined by a SplitToSequence op.
+  ONNXSplitToSequenceOp splitToSequence =
+      inputSequence.getDefiningOp<ONNXSplitToSequenceOp>();
+  if (!splitToSequence)
+    return false;
+
+  // Check the pattern of the SplitToSequence op.
+  Value input = splitToSequence.getInput();
+  if (!hasStaticShape(input.getType()))
+    return false;
+  Value split = splitToSequence.getSplit();
+  if (!isScalarConstantTensor(split))
+    return false;
+
+  return true;
+}
+
+Value replaceSequenceAt(
+    PatternRewriter &rewriter, Location loc, Value sequenceAtResult) {
+  ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp<ONNXSequenceAtOp>();
+
+  Value inputSequence = op.getInputSequence();
+  Value position = op.getPosition();
+
+  ONNXConstantOp positionConstant =
+      mlir::cast<ONNXConstantOp>(position.getDefiningOp());
+  int64_t positionInt = getScalarValue<int64_t>(positionConstant);
+
+  ONNXSplitToSequenceOp splitToSequence =
+      mlir::cast<ONNXSplitToSequenceOp>(inputSequence.getDefiningOp());
+
+  Value input = splitToSequence.getInput();
+  Value split = splitToSequence.getSplit();
+
+  ONNXConstantOp splitConstant =
+      mlir::cast<ONNXConstantOp>(split.getDefiningOp());
+  int64_t splitInt = getScalarValue<int64_t>(splitConstant);
+  int64_t axisInt = splitToSequence.getAxis();
+
+  auto shape = getShape(input.getType());
+
+  OnnxBuilder create(rewriter, loc);
+
+  Type sequenceElementType =
+      mlir::cast<SeqType>(inputSequence.getType()).getElementType();
+  mlir::SmallVector<mlir::Type, 4> outputTypes(
+      shape[axisInt] / splitInt, sequenceElementType);
+  auto numSplit = create.constantInt64(
+      mlir::SmallVector<int64_t, 4>(shape[axisInt] / splitInt, splitInt));
+  auto resultRange = create.split(outputTypes, input, numSplit, axisInt);
+  auto rawResult = resultRange[positionInt];
+
+  if (rawResult.getType() == sequenceAtResult.getType())
+    return rawResult;
+
+  // Temporary code for an error in the model generated by torch.onnx.export:
+  // the dim of the result of the SequenceAt op is different from the element
+  // type of the sequence.
+  // My assumption is that the exporter is confused by the squeeze and
+  // unsqueeze that follow the SequenceAt. There are two cases in the model:
+  // clang-format off
+  // Case #1:
+  // %16 = "onnx.SequenceAt"(%14, %15) {onnx_node_name = "n0"} :
+  //       (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x100xf32>
+  // %23 = "onnx.Unsqueeze"(%16, %22) {onnx_node_name = "n2"} :
+  //       (tensor<1x100xf32>, tensor<i64>) -> tensor<1x1x100xf32>
+  // Case #2:
+  // %67 = "onnx.SequenceAt"(%66, %15) {onnx_node_name = "n0"} :
+  //       (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
+  // %71 = "onnx.Sigmoid"(%67) {onnx_node_name = "node_Sigmoid_60"} :
+  //       (tensor<1x1x100xf32>) -> tensor<1x1x100xf32>
+  // clang-format on
+  // Thus, the compiler squeezes the tensor if needed.
+  return create.squeeze(
+      sequenceAtResult.getType(), rawResult, create.constantInt64(axisInt));
+}
+
 bool shouldDecomposeConvTransposeOp(Value convTransposeResult) {
   ONNXConvTransposeOp op =
       mlir::cast<ONNXConvTransposeOp>(convTransposeResult.getDefiningOp());

@@ -1246,6 +1360,10 @@ void DecomposeONNXToONNXPass::runOnOperation() {
     return !isConcatFuseMatched(op, shapeOp, transposeOp);
   });
 
+  target.addDynamicallyLegalOp<ONNXSequenceAtOp>([](ONNXSequenceAtOp op) {
+    return !onnx_mlir::canSequenceAtBeReplaced(op.getResult());
+  });
+
   // Rewrite ONNXConstantOp with scalar values into the one using ElementAttrs.
   target.addDynamicallyLegalOp<ONNXConstantOp>([](ONNXConstantOp op) {
     return !(op.getValueFloatAttr() || op.getValueFloatsAttr() ||
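
To make the Slice alternative discussed in the comment block concrete, here is a minimal sketch (not what this commit emits; the pass generates Split). Assuming %input : tensor<1x1x400xf32>, axis = 2, split = 100, and position = 1, as in the tests below, start = 1*100 and end = (1+1)*100 on the split axis:

  %starts = onnx.Constant dense<[0, 0, 100]> : tensor<3xi64>
  %ends = onnx.Constant dense<[1, 1, 200]> : tensor<3xi64>
  %axes = onnx.Constant dense<[0, 1, 2]> : tensor<3xi64>
  %steps = onnx.Constant dense<[1, 1, 1]> : tensor<3xi64>
  // One Slice per SequenceAt; unlike identical Split ops, Slices with
  // different positions cannot be merged into a single op.
  %res = "onnx.Slice"(%input, %starts, %ends, %axes, %steps) : (tensor<1x1x400xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x1x100xf32>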

src/Dialect/ONNX/Transforms/Decompose.td

Lines changed: 18 additions & 0 deletions

@@ -71,6 +71,12 @@ def createScalarDenseAttrRank0
 def ReshapeElementsAttrToRank0 : NativeCodeCall<
   "onnx_mlir::OnnxElementsAttrBuilder($0.getContext()).reshape(cast<ElementsAttr>($0), {})">;
 
+def ReplaceSequenceAt : NativeCodeCall<
+  "onnx_mlir::replaceSequenceAt($_builder, $_loc, $0)">;
+
+def CanSequenceAtBeReplaced :
+  Constraint<CPred<"::onnx_mlir::canSequenceAtBeReplaced($_self)">, "check whether the SequenceAt can be replaced with Split">;
+
 // Create a DenseElementsAttr from a single attribute.
 def createDenseArrayAttrFromSingleAttr
     : NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">;

@@ -620,4 +626,16 @@ def ConstantOpNormalizationPattern6: Pat<
   [(AttributeIsNotNull:$stringsAttr)]
 >;
 
+// Optimize the pattern coming from torch.nn.LSTM exported from pytorch:
+// %32 = "onnx.SplitToSequence"(%30, %27) {axis = 0 : si64, keepdims = 0 : si64, onnx_node_name = "n1"} : (tensor<1x1x100xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
+// %33 = "onnx.SequenceAt"(%32, %26) {onnx_node_name = "n0"} : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x100xf32>
+// When the shape and the size/axis related values are constant, this
+// sequence of code can be rewritten with onnx.Split.
+
+def ReplaceSequenceAtPattern: Pat<
+  (ONNXSequenceAtOp:$res $seq, $position),
+  (ReplaceSequenceAt $res),
+  [(CanSequenceAtBeReplaced:$res)]
+>;
+
 #endif // ONNX_DECOMPOSE
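
For illustration, a hedged counter-example (assumed; it is not one of the commit's tests): when the position is not a compile-time constant, the CanSequenceAtBeReplaced constraint fails, the op stays dynamically legal, and the SequenceAt survives decomposition.

  func.func @sequence_at_dynamic_pos(%arg0 : tensor<1x1x400xf32>, %pos : tensor<i64>) -> tensor<1x1x100xf32> {
    %split = onnx.Constant dense<100> : tensor<i64>
    %seq = "onnx.SplitToSequence"(%arg0, %split) {axis = 2 : si64, keepdims = 1 : si64} : (tensor<1x1x400xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
    // %pos is a block argument, so isDenseONNXConstant(position) is false
    // and the pattern does not fire.
    %res = "onnx.SequenceAt"(%seq, %pos) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
    return %res : tensor<1x1x100xf32>
  }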
Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+
+// RUN: onnx-mlir-opt --decompose-onnx --canonicalize %s -split-input-file | FileCheck %s
+
+// -----
+
+// Test one pattern in lstm_no_data.onnx.
+// The type of the output of SequenceAt is not the same as the element type
+// of the input sequence.
+func.func @sequence_at_squeezed(%arg0 : tensor<1x1x100xf32>) -> tensor<1x100xf32> {
+  %26 = onnx.Constant dense<0> : tensor<i64>
+  %27 = onnx.Constant dense<1> : tensor<i64>
+  %32 = "onnx.SplitToSequence"(%arg0, %27) {axis = 0 : si64, keepdims = 0 : si64} : (tensor<1x1x100xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
+  %33 = "onnx.SequenceAt"(%32, %26) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x100xf32>
+  return %33 : tensor<1x100xf32>
+// CHECK-LABEL:  func.func @sequence_at_squeezed
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<1x1x100xf32>) -> tensor<1x100xf32> {
+// CHECK-DAG:    [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<1xi64>
+// CHECK-DAG:    [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<1xi64>
+// CHECK:        [[VAR_2_:%.+]] = "onnx.Split"([[PARAM_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x1x100xf32>
+// CHECK:        [[VAR_3_:%.+]] = "onnx.Squeeze"([[VAR_2_]], [[VAR_0_]]) : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x100xf32>
+// CHECK:        return [[VAR_3_]] : tensor<1x100xf32>
+// CHECK:        }
+}
+
+func.func @sequence_at_multi(%arg0 : tensor<1x1x400xf32>) -> tensor<1x1x100xf32> {
+  %15 = onnx.Constant dense<0> : tensor<i64>
+  %38 = onnx.Constant dense<1> : tensor<i64>
+  %65 = onnx.Constant dense<100> : tensor<i64>
+  %66 = "onnx.SplitToSequence"(%arg0, %65) {axis = 2 : si64, keepdims = 1 : si64} : (tensor<1x1x400xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
+  %67 = "onnx.SequenceAt"(%66, %15) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
+  %68 = "onnx.SequenceAt"(%66, %38) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
+  %40 = "onnx.Add"(%67, %68) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>
+  return %40 : tensor<1x1x100xf32>
+// CHECK-LABEL:  func.func @sequence_at_multi
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<1x1x400xf32>) -> tensor<1x1x100xf32> {
+// CHECK:        [[VAR_0_:%.+]] = onnx.Constant dense<100> : tensor<4xi64>
+// CHECK-DAG:    [[VAR_1_:%.+]]:4 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
+// CHECK-DAG:    [[VAR_2_:%.+]]:4 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
+// CHECK:        [[VAR_3_:%.+]] = "onnx.Add"([[VAR_1_]]#0, [[VAR_2_]]#1) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>
+// CHECK:        return [[VAR_3_]] : tensor<1x1x100xf32>
+// CHECK:        }
+}
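
As the comment in Decompose.cpp notes, the two identical onnx.Split ops checked above are expected to be merged by later optimization. A hedged sketch of the post-merge form this implies (assumed; the test itself does not check it):

  %0 = onnx.Constant dense<100> : tensor<4xi64>
  // A single Split feeds both former SequenceAt results.
  %1:4 = "onnx.Split"(%arg0, %0) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
  %2 = "onnx.Add"(%1#0, %1#1) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>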
