@@ -643,8 +643,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         llvm::SmallVector<int64_t> axes;
         int64_t keepDims;
         int64_t noop_with_empty_axes;
-        if (binder.tensorOperand(data) ||
-            binder.tensorResultType(resultType) ||
+        if (binder.tensorOperand(data) || binder.tensorResultType(resultType) ||
             binder.s64IntegerArrayAttr(axes, "axes", 0) ||
             binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
             binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
@@ -1092,7 +1091,180 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         rewriter.replaceOp(binder.op, operand);
         return success();
       });
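+  // onnx.Slice is unrolled into one torch.aten.slice.Tensor per slicing
+  // axis; the 1-D starts/ends/axes/steps operands are read element-wise.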
+  patterns.onOp(
+      "Slice", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultTorchType;
+        Value operand, starts, ends;
+        // `axes` and `steps` are optional operands; defaults are built below.
+
+        if (binder.tensorOperandAtIndex(operand, 0) ||
+            binder.tensorOperandAtIndex(starts, 1) ||
+            binder.tensorOperandAtIndex(ends, 2) ||
+            binder.tensorResultType(resultTorchType)) {
+          return failure();
+        }
+
+        auto context = rewriter.getContext();
+        auto operandTorchTy = operand.getType().cast<Torch::ValueTensorType>();
+        auto operandTy =
+            operandTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+
+        if (!operandTy)
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "Expected tensor operator argument to be a ranked tensor type");
+
+        auto startsTorchTy = starts.getType().cast<Torch::ValueTensorType>();
+        auto startsTy =
+            startsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+        auto endsTorchTy = ends.getType().cast<Torch::ValueTensorType>();
+        auto endsTy =
+            endsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+        if (!startsTy || !endsTy)
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected starts and ends to be ranked tensor types");
+        int startSize = startsTy.getDimSize(0);
+        int endSize = endsTy.getDimSize(0);
+        auto resultTy =
+            resultTorchType.toBuiltinTensor().dyn_cast<RankedTensorType>();
+        if (!resultTy)
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected result type to be a ranked tensor type");
+
+        Location loc = binder.getLoc();
+
+        // Bind `axes` from the operands or fall back to its default value.
+        Value axes;
+        if (binder.getNumOperands() >= 4) {
+          if (binder.tensorOperandAtIndex(axes, 3)) {
+            return failure();
+          }
+        } else {
+          // The default axes value is the range from 0 to the number of
+          // dimensions.
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          auto defaultAxesType = Torch::ValueTensorType::get(
+              context, ArrayRef<int64_t>{operandTy.getRank()},
+              rewriter.getIntegerType(64, /*signed*/ 1));
+          Value arangeLength = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                      operandTy.getRank()));
+          axes = rewriter.create<Torch::AtenArangeOp>(
+              loc, defaultAxesType, arangeLength, none, none, none, none);
+        }
+
+        // Bind `steps` from the operands or fall back to its default value.
+        Value steps;
+        if (binder.getNumOperands() >= 5) {
+          if (binder.tensorOperandAtIndex(steps, 4)) {
+            return failure();
+          }
+        } else {
+          // The default `steps` value is a 1-D tensor of ones whose size
+          // equals the rank of the operand.
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          auto defaultStepsType = Torch::ValueTensorType::get(
+              context, ArrayRef<int64_t>{operandTy.getRank()},
+              rewriter.getIntegerType(64, /*signed*/ 1));
+          Value sizeStepInput = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                      operandTy.getRank()));
+          Value sizeStepsInput = rewriter.create<Torch::PrimListConstructOp>(
+              loc,
+              Torch::ListType::get(
+                  Torch::IntType::get(binder.op->getContext())),
+              sizeStepInput);
+          steps = rewriter.create<Torch::AtenOnesOp>(
+              loc, defaultStepsType, sizeStepsInput, none, none, none, none);
+        }
 
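+        // Validate that starts, ends, axes, and steps agree in size before
+        // unrolling the slice into per-axis ops.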
+        if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 &&
+              startSize == endSize))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected the rank of starts and ends tensors to be 1 "
+                         "and their dimensions to match");
+
+        auto axesTorchTy = axes.getType().cast<Torch::ValueTensorType>();
+        auto axesTy =
+            axesTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
+
+        if (!axesTy || axesTy.getDimSize(0) != endSize)
+          return rewriter.notifyMatchFailure(
+              binder.op, "Axes should be the same size as starts and ends");
+        int64_t numAxes = axesTy.getDimSize(0);
+
+        auto stepsTy = steps.getType()
+                           .cast<Torch::ValueTensorType>()
+                           .toBuiltinTensor()
+                           .dyn_cast<RankedTensorType>();
+
+        if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0)))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Steps should be the same size as starts and ends");
+
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
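+        // Helper that reads element k of a 1-D tensor: index_select along
+        // dim 0 produces a one-element tensor, and aten.item unwraps it
+        // into a !torch.int scalar.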
+        auto select = [&](Value v, Value k) -> Value {
+          auto ty = v.getType().cast<Torch::ValueTensorType>();
+          auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
+              loc,
+              Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
+                                          ty.getOptionalDtype()),
+              v, zero, k);
+          Value item = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::IntType>(), sel);
+          return item;
+        };
+
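+        // Dimensions changed by the slice are unknown until the last slice
+        // is applied, so mark them dynamic in the intermediate type.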
+        llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
+        for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
+          if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) {
+            intermediateShape[i] = -1;
+          }
+        }
+        auto intermediateType = Torch::ValueTensorType::get(
+            context, intermediateShape, resultTorchType.getOptionalDtype());
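+        // Emit one aten.slice.Tensor per requested axis; only the final
+        // slice can use the fully refined result type.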
+        for (int i = 0; i < numAxes; ++i) {
+          Value k = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+          Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+              loc,
+              Torch::ValueTensorType::get(
+                  context, ArrayRef<int64_t>{1},
+                  rewriter.getIntegerType(64, /*signed*/ 1)),
+              k);
+
+          Value start = select(starts, kTensor);
+          Value end = select(ends, kTensor);
+          Value axis = select(axes, kTensor);
+          Value step = select(steps, kTensor);
+
+          auto sliceType = intermediateType;
+          if (i == numAxes - 1)
+            sliceType = resultTorchType;
+          operand = rewriter.create<Torch::AtenSliceTensorOp>(
+              loc, sliceType, operand, axis, start, end, step);
+        }
+
+        rewriter.replaceOp(binder.op, operand);
+        return success();
+      });
   patterns.onOp(
       "Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
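For context, a minimal sketch of the IR this pattern matches, written in the
style of the existing TorchOnnxToTorch lit tests (the function name, shapes,
and attributes here are illustrative assumptions, not taken from this commit).
Slicing axis 0 of a 20x10x5 input down to three rows lowers to a single
torch.aten.slice.Tensor once the scalar start/end/axis/step values have been
extracted:

func.func @test_slice(%arg0: !torch.vtensor<[20,10,5],f32>,
    %starts: !torch.vtensor<[1],si64>, %ends: !torch.vtensor<[1],si64>,
    %axes: !torch.vtensor<[1],si64>, %steps: !torch.vtensor<[1],si64>)
    -> !torch.vtensor<[3,10,5],f32>
    attributes {torch.onnx_meta.opset_version = 13 : si64} {
  %0 = torch.operator "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps)
      : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[1],si64>,
         !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>,
         !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,10,5],f32>
  return %0 : !torch.vtensor<[3,10,5],f32>
}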