@@ -333,6 +333,120 @@ bool hasStaticSpatialDims(Value v) {
333333 return llvm::none_of (Ds, ShapedType::isDynamic);
334334}
335335
336+ // In the following pattern, a SequenceAt can be replaced with Split
337+ // %seq = onnx.SplitToSequence(%input, %split) {%axis : }
338+ // %res = onnx.SequenceAt(%seq, %position)
339+ // We just try to avoid using the sequence related ops, which are less
340+ // optimized, or even not implemented in onnx-mlir.
341+ // In the targeted use case, %split and %position are constant scalar and the
342+ // tensor of %input and %res have static shape.
343+ // This condition greatly reduces the complexity of code generation to replace
344+ // SequenceAt with split op
345+ // %res = onnx.Split(%input, onnx.expand(%split, %input.shape()[%axis]))
346+ // {%axis : } : %position
347+ // onnx.expand(%split, %input.shape()[%axis]) can be a constant under the
348+ // assumed condition.
349+ // Here %position has to be a compile-time constant.
350+ // For multiple SequenceAt from the same SplitToSequence result, the onnx.split
351+ // for different SequenceAt are expected to be merged by optimization.
352+ // Alternatively, Slice can be used
353+ // %res = onnx.Slice(%input, %start, %end, %step)
354+ // The start, and end for slice will be onnx.constant:
355+ // start: %position*%split for %axis, 0 for other dimensions
356+ // end: (%position+1)*%split for %axis, upper bound for other dimensions
357+ // step: 1 for all dimensions
358+ // The split approach may have better performance than the alternative slice
359+ // approach, because the slicing is done separately.
360+
361+ bool canSequenceAtBeReplaced (Value sequenceAtResult) {
362+ if (!hasStaticShape (sequenceAtResult.getType ()))
363+ return false ;
364+
365+ ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp <ONNXSequenceAtOp>();
366+
367+ Value inputSequence = op.getInputSequence ();
368+ Value position = op.getPosition ();
369+
370+ if (!isDenseONNXConstant (position))
371+ return false ;
372+
373+ // Input sequence should be defined with SplitToSequence
374+ ONNXSplitToSequenceOp splitToSequence =
375+ inputSequence.getDefiningOp <ONNXSplitToSequenceOp>();
376+ if (!splitToSequence)
377+ return false ;
378+
379+ // Check the pattern of the SplitToSequence op
380+ Value input = splitToSequence.getInput ();
381+ if (!hasStaticShape (input.getType ()))
382+ return false ;
383+ Value split = splitToSequence.getSplit ();
384+ if (!isScalarConstantTensor (split))
385+ return false ;
386+
387+ return true ;
388+ }
389+
390+ Value replaceSequenceAt (
391+ PatternRewriter &rewriter, Location loc, Value sequenceAtResult) {
392+ ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp <ONNXSequenceAtOp>();
393+
394+ Value inputSequence = op.getInputSequence ();
395+ Value position = op.getPosition ();
396+
397+ ONNXConstantOp positionConstant =
398+ mlir::cast<ONNXConstantOp>(position.getDefiningOp ());
399+ int64_t positionInt = getScalarValue<int64_t >(positionConstant);
400+
401+ ONNXSplitToSequenceOp splitToSequence =
402+ mlir::cast<ONNXSplitToSequenceOp>(inputSequence.getDefiningOp ());
403+
404+ Value input = splitToSequence.getInput ();
405+ Value split = splitToSequence.getSplit ();
406+
407+ ONNXConstantOp splitConstant =
408+ mlir::cast<ONNXConstantOp>(split.getDefiningOp ());
409+ int64_t splitInt = getScalarValue<int64_t >(splitConstant);
410+ int64_t axisInt = splitToSequence.getAxis ();
411+
412+ auto shape = getShape (input.getType ());
413+
414+ OnnxBuilder create (rewriter, loc);
415+
416+ Type sequenceElementType =
417+ mlir::cast<SeqType>(inputSequence.getType ()).getElementType ();
418+ mlir::SmallVector<mlir::Type, 4 > outputTypes (
419+ shape[axisInt] / splitInt, sequenceElementType);
420+ auto numSplit = create.constantInt64 (
421+ mlir::SmallVector<int64_t , 4 >(shape[axisInt] / splitInt, splitInt));
422+ auto resultRange = create.split (outputTypes, input, numSplit, axisInt);
423+ auto rawResult = resultRange[positionInt];
424+
425+ if (rawResult.getType () == sequenceAtResult.getType ())
426+ return rawResult;
427+
428+ // Temporary code for the error in the model generated by torch.onnx.export
429+ // The the dim of the reuslt of SequenceAt op is different from the element
430+ // type of the sequence..
431+ // My assumption is that the exporter is confused with squeeze and unsqueeze
432+ // followed by the SequenceAt. There are two cases in the model:
433+ // clang-format off
434+ // Case #1:
435+ // %16 = "onnx.SequenceAt"(%14, %15) {onnx_node_name = "n0"} :
436+ // (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x100xf32>
437+ // %23 = "onnx.Unsqueeze"(%16, %22) {onnx_node_name = "n2"} :
438+ // (tensor<1x100xf32>, tensor<i64>) -> tensor<1x1x100xf32>
439+ // Case#2:
440+ // %67 = "onnx.SequenceAt"(%66, %15) {onnx_node_name = "n0"} :
441+ // (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
442+ // %71 = "onnx.Sigmoid"(%67) {onnx_node_name = "node_Sigmoid_60"} :
443+ // (tensor<1x1x100xf32>) -> tensor<1x1x100xf32>
444+ // clang-format on
445+ // Thus, the compiler squeeze the tensor if needed.
446+ return create.squeeze (
447+ sequenceAtResult.getType (), rawResult, create.constantInt64 (axisInt));
448+ }
449+
336450bool shouldDecomposeConvTransposeOp (Value convTransposeResult) {
337451 ONNXConvTransposeOp op =
338452 mlir::cast<ONNXConvTransposeOp>(convTransposeResult.getDefiningOp ());
@@ -1246,6 +1360,10 @@ void DecomposeONNXToONNXPass::runOnOperation() {
12461360 return !isConcatFuseMatched (op, shapeOp, transposeOp);
12471361 });
12481362
1363+ target.addDynamicallyLegalOp <ONNXSequenceAtOp>([](ONNXSequenceAtOp op) {
1364+ return !onnx_mlir::canSequenceAtBeReplaced (op.getResult ());
1365+ });
1366+
12491367 // Rewrite ONNXConstantOp with scalar values into the one using ElementAttrs.
12501368 target.addDynamicallyLegalOp <ONNXConstantOp>([](ONNXConstantOp op) {
12511369 return !(op.getValueFloatAttr () || op.getValueFloatsAttr () ||
0 commit comments