@@ -794,6 +794,148 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
794794 return success ();
795795 });
796796
797+ // split with fixed-size parts
798+ // Arguments:
799+ // - input: the tensor to split
800+ // Attributes:
801+ // - axis: the axis along which to split the input
802+ // - num_outputs: the number of outputs to produce
803+ // Outputs:
804+ // - outputs: the produced outputs. Variadic with num_outputs elements.
805+ // Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of
806+ // tensors
807+ // so we need to unpack the list
808+ patterns.onOp (
809+ " Split" , 1 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
810+ Value self;
811+ int64_t axis;
812+ int64_t num_outputs;
813+ if (binder.tensorOperand (self))
814+ return rewriter.notifyMatchFailure (
815+ binder.op , " Not converting to AtenSplitTensorOp due to input "
816+ " tensor mismatch" );
817+ if (binder.s64IntegerAttr (axis, " axis" , 0 ))
818+ return rewriter.notifyMatchFailure (binder.op ,
819+ " Failed to get axis attribute" );
820+ if (binder.s64IntegerAttr (num_outputs, " num_outputs" , 0 ))
821+ return rewriter.notifyMatchFailure (
822+ binder.op , " Failed to get num_outputs attribute" );
823+
824+ auto result0Ty =
825+ binder.op ->getResult (0 ).getType ().cast <Torch::ValueTensorType>();
826+ auto selfTy = self.getType ().cast <Torch::ValueTensorType>();
827+
828+ int64_t dim = axis;
829+ if (dim < 0 )
830+ dim += selfTy.getSizes ().size ();
831+
832+ // set intermediate shape to the shape of the first result
833+ // if the results are of different shapes
834+ // set the splitted axis to variable shape
835+ llvm::SmallVector<int64_t > intermediateShape (result0Ty.getSizes ());
836+ for (auto result : binder.op ->getResultTypes ()) {
837+ int64_t d = result.cast <Torch::ValueTensorType>().getSizes ()[dim];
838+ intermediateShape[dim] = d == intermediateShape[dim] ? d : -1 ;
839+ }
840+
841+ Value dimValue = rewriter.create <Torch::ConstantIntOp>(
842+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
843+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), dim));
844+
845+ Value splitSize = rewriter.create <Torch::ConstantIntOp>(
846+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
847+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), num_outputs));
848+
849+ // TODO: Attempting to use the shape expected by the ONNX mlir as ground
850+ // truth. For now just use dynamic shapes.
851+ auto resultOuterType =
852+ Torch::ListType::get (rewriter.getType <Torch::ValueTensorType>(
853+ /* std::optional<llvm::ArrayRef<int64_t>>=*/ intermediateShape,
854+ result0Ty.getOptionalDtype ()));
855+ Torch::AtenSplitTensorOp new_op =
856+ rewriter.create <Torch::AtenSplitTensorOp>(
857+ binder.getLoc (), resultOuterType, self, splitSize, dimValue);
858+
859+ // the onnx op is variadic with multiple results, but AtenSplitWithSizes
860+ // outputs a list so we need to unpack the list
861+ rewriter.replaceOpWithNewOp <Torch::PrimListUnpackOp>(
862+ binder.op , binder.op ->getResults ().getType (), new_op.getResult ());
863+
864+ return success ();
865+ });
866+
867+ // split with variable parts
868+ // Arguments:
869+ // - input: the tensor to split
870+ // - split: the sizes of the splits to be produced
871+ // Attributes:
872+ // - axis: the axis along which to split the input
873+ // - num_outputs: the number of outputs to produce
874+ // Outputs:
875+ // - outputs: the produced outputs. Variadic with num_outputs elements.
876+ // Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of
877+ // tensors
878+ // so we need to unpack the list
879+ patterns.onOp (
880+ " Split" , 1 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
881+ Value self;
882+ Value split;
883+ int64_t axis;
884+ int64_t num_outputs;
885+ if (binder.tensorOperandAtIndex (self, 0 ) ||
886+ binder.tensorOperandAtIndex (split, 1 ))
887+ return rewriter.notifyMatchFailure (
888+ binder.op , " Not converting to AtenSplitWithSizesOp due to input "
889+ " tensor mismatch" );
890+ if (binder.s64IntegerAttr (axis, " axis" , 0 ))
891+ return rewriter.notifyMatchFailure (binder.op ,
892+ " Failed to get axis attribute" );
893+ if (binder.s64IntegerAttr (num_outputs, " num_outputs" , 0 ))
894+ return rewriter.notifyMatchFailure (
895+ binder.op , " Failed to get num_outputs attribute" );
896+
897+ auto result0Ty =
898+ binder.op ->getResult (0 ).getType ().cast <Torch::ValueTensorType>();
899+ auto selfTy =
900+ cast<Torch::ValueTensorType>(binder.op ->getOperand (0 ).getType ());
901+
902+ int64_t dim = axis;
903+ if (dim < 0 )
904+ dim += selfTy.getSizes ().size ();
905+
906+ llvm::SmallVector<int64_t > intermediateShape (result0Ty.getSizes ());
907+ for (auto result : binder.op ->getResultTypes ()) {
908+ int64_t d = result.cast <Torch::ValueTensorType>().getSizes ()[dim];
909+ intermediateShape[dim] = d == intermediateShape[dim] ? d : -1 ;
910+ }
911+
912+ Torch::PrimTolistOp splitToList = rewriter.create <Torch::PrimTolistOp>(
913+ binder.getLoc (),
914+ Torch::ListType::get (rewriter.getType <Torch::IntType>()), split);
915+
916+ Value dimValue = rewriter.create <Torch::ConstantIntOp>(
917+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
918+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), dim));
919+
920+ // TODO: Attempting to use the shape expected by the ONNX mlir as ground
921+ // truth. For now just use dynamic shapes.
922+ auto resultOuterType =
923+ Torch::ListType::get (rewriter.getType <Torch::ValueTensorType>(
924+ /* std::optional<llvm::ArrayRef<int64_t>>=*/ intermediateShape,
925+ result0Ty.getOptionalDtype ()));
926+ Torch::AtenSplitWithSizesOp new_op =
927+ rewriter.create <Torch::AtenSplitWithSizesOp>(
928+ binder.getLoc (), resultOuterType, self,
929+ splitToList.getResult (0 ), dimValue);
930+
931+ // the onnx op is variadic with multiple results, but AtenSplitWithSizes
932+ // outputs a list so we need to unpack the list
933+ rewriter.replaceOpWithNewOp <Torch::PrimListUnpackOp>(
934+ binder.op , binder.op ->getResults ().getType (), new_op.getResult ());
935+
936+ return success ();
937+ });
938+
797939 patterns.onOp (" Tan" , 7 ,
798940 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
799941 Torch::ValueTensorType resultType;
0 commit comments