@@ -873,4 +873,95 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
873873 rewriter.replaceOp (binder.op , operand);
874874 return success ();
875875 });
876+
877+ patterns.onOp (
878+ " Reshape" , 5 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
879+ Torch::ValueTensorType resultType;
880+ Value data;
881+ Value shape;
882+ int64_t allowzero;
883+ if (binder.tensorOperands (data, shape) ||
884+ binder.tensorResultType (resultType) ||
885+ binder.s64IntegerAttr (allowzero, " allowzero" , 0 ))
886+ return failure ();
887+ Torch::BaseTensorType shapeType =
888+ shape.getType ().cast <Torch::BaseTensorType>();
889+ SmallVector<Value> dimList;
890+ SmallVector<int64_t > selectSizes;
891+ selectSizes.push_back (1 );
892+ Type selectResultType = shapeType.getWithSizesAndDtype (
893+ llvm::ArrayRef (selectSizes), shapeType.getOptionalDtype ());
894+ auto shapeSizes =
895+ dyn_cast<Torch::ValueTensorType>(shape.getType ()).getSizes ();
896+ auto dataSizes =
897+ dyn_cast<Torch::ValueTensorType>(data.getType ()).getSizes ();
898+ Value zero = rewriter.create <Torch::ConstantIntOp>(
899+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
900+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), 0 ));
901+ if (allowzero == 0 ) {
902+ // convert shape (tensor) into torch int list while dealing with zero
903+ // vals
904+ for (int i = 0 ; i < shapeSizes[0 ]; i++) {
905+ // Go through the shape list and get each dim in the list
906+ Value selectIndex = rewriter.create <Torch::ConstantIntOp>(
907+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
908+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), i));
909+ Value extract = rewriter.create <Torch::AtenSelectIntOp>(
910+ binder.getLoc (), selectResultType, shape, zero, selectIndex);
911+ Value dim = rewriter.create <Torch::AtenItemOp>(
912+ binder.getLoc (), rewriter.getType <Torch::IntType>(), extract);
913+ // deal with zero axis values: replace with original dim value in
914+ // input
915+ Value isZero =
916+ rewriter.create <Torch::AtenEqIntOp>(binder.getLoc (), dim, zero);
917+ isZero =
918+ rewriter.create <Torch::AtenIntBoolOp>(binder.getLoc (), isZero);
919+ Value adjustment;
920+ int64_t inputDimsSize = dataSizes.size ();
921+ if (i < inputDimsSize) {
922+ adjustment = rewriter.create <Torch::ConstantIntOp>(
923+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
924+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ),
925+ dataSizes[i]));
926+ }
927+ // Will never have a 0 in the shape tensor input at an index out of
928+ // bounds of original input dims Therefore, no need to adjust
929+ else {
930+ adjustment = zero;
931+ }
932+ Value finalOffset = rewriter.create <Torch::AtenMulIntOp>(
933+ binder.getLoc (), isZero, adjustment);
934+ Value finalDim = rewriter.create <Torch::AtenAddIntOp>(
935+ binder.getLoc (), dim, finalOffset);
936+ dimList.push_back (finalDim);
937+ }
938+ Value dimValueList = rewriter.create <Torch::PrimListConstructOp>(
939+ binder.getLoc (),
940+ Torch::ListType::get (
941+ Torch::IntType::get (binder.op ->getContext ())),
942+ dimList);
943+ rewriter.replaceOpWithNewOp <Torch::AtenReshapeOp>(
944+ binder.op , resultType, data, dimValueList);
945+ return success ();
946+ }
947+ // convert axes (tensor) into torch int list
948+ for (int i = 0 ; i < shapeSizes[0 ]; i++) {
949+ // Go through the axes list and get each dim in the list
950+ Value selectIndex = rewriter.create <Torch::ConstantIntOp>(
951+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
952+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), i));
953+ Value extract = rewriter.create <Torch::AtenSelectIntOp>(
954+ binder.getLoc (), selectResultType, shape, zero, selectIndex);
955+ Value dim = rewriter.create <Torch::AtenItemOp>(
956+ binder.getLoc (), rewriter.getType <Torch::IntType>(), extract);
957+ dimList.push_back (dim);
958+ }
959+ Value dimValueList = rewriter.create <Torch::PrimListConstructOp>(
960+ binder.getLoc (),
961+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
962+ dimList);
963+ rewriter.replaceOpWithNewOp <Torch::AtenReshapeOp>(binder.op , resultType,
964+ data, dimValueList);
965+ return success ();
966+ });
876967}
0 commit comments