Commit ee75e8d

[MLIR][ONNX] Add OnnxToTorch support for Reshape Op (#2698)
This commit adds OnnxToTorch support for the Reshape op.
1 parent 0849fd0 commit ee75e8d
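
For context, the ONNX-dialect form that this new pattern matches looks roughly like the snippet below. This is an illustrative sketch only: the function name, tensor shapes, and result type are hypothetical and are not taken from the commit's test file.

func.func @reshape_example(%data: !torch.vtensor<[2,3,4],f32>,
                           %shape: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> {
  // Unconverted ONNX op as produced by the torch-mlir ONNX importer; the
  // pattern added below rewrites it into torch.aten.reshape.
  %0 = torch.operator "onnx.Reshape"(%data, %shape) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32>
  return %0 : !torch.vtensor<[2,6,2],f32>
}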

File tree

2 files changed: +281 -0 lines changed


lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 91 additions & 0 deletions
@@ -873,4 +873,95 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        rewriter.replaceOp(binder.op, operand);
        return success();
      });

  patterns.onOp(
      "Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value data;
        Value shape;
        int64_t allowzero;
        if (binder.tensorOperands(data, shape) ||
            binder.tensorResultType(resultType) ||
            binder.s64IntegerAttr(allowzero, "allowzero", 0))
          return failure();
        Torch::BaseTensorType shapeType =
            shape.getType().cast<Torch::BaseTensorType>();
        SmallVector<Value> dimList;
        SmallVector<int64_t> selectSizes;
        selectSizes.push_back(1);
        Type selectResultType = shapeType.getWithSizesAndDtype(
            llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
        auto shapeSizes =
            dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
        auto dataSizes =
            dyn_cast<Torch::ValueTensorType>(data.getType()).getSizes();
        Value zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        if (allowzero == 0) {
          // Convert the shape tensor into a torch int list, replacing any
          // zero entries with the corresponding input dimension.
          for (int i = 0; i < shapeSizes[0]; i++) {
            // Go through the shape list and get each dim in the list.
            Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getType<Torch::IntType>(),
                rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
            Value extract = rewriter.create<Torch::AtenSelectIntOp>(
                binder.getLoc(), selectResultType, shape, zero, selectIndex);
            Value dim = rewriter.create<Torch::AtenItemOp>(
                binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
            // Deal with zero axis values: replace them with the original dim
            // value from the input.
            Value isZero =
                rewriter.create<Torch::AtenEqIntOp>(binder.getLoc(), dim, zero);
            isZero =
                rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isZero);
            Value adjustment;
            int64_t inputDimsSize = dataSizes.size();
            if (i < inputDimsSize) {
              adjustment = rewriter.create<Torch::ConstantIntOp>(
                  binder.getLoc(), rewriter.getType<Torch::IntType>(),
                  rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                          dataSizes[i]));
            } else {
              // The shape tensor will never contain a 0 at an index that is
              // out of bounds of the original input dims, so no adjustment is
              // needed.
              adjustment = zero;
            }
            Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
                binder.getLoc(), isZero, adjustment);
            Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
                binder.getLoc(), dim, finalOffset);
            dimList.push_back(finalDim);
          }
          Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
              binder.getLoc(),
              Torch::ListType::get(
                  Torch::IntType::get(binder.op->getContext())),
              dimList);
          rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
              binder.op, resultType, data, dimValueList);
          return success();
        }
        // allowzero != 0: convert the shape tensor into a torch int list
        // as-is (a zero entry keeps the dimension size 0).
        for (int i = 0; i < shapeSizes[0]; i++) {
          // Go through the shape list and get each dim in the list.
          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(),
              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
              binder.getLoc(), selectResultType, shape, zero, selectIndex);
          Value dim = rewriter.create<Torch::AtenItemOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
          dimList.push_back(dim);
        }
        Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            dimList);
        rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(binder.op, resultType,
                                                          data, dimValueList);
        return success();
      });
}
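
As a rough illustration of what the allowzero == 0 path above emits for a single shape entry, consider a hypothetical input %data : !torch.vtensor<[4],f32> reshaped with a 1-element shape tensor whose entry is 0. The SSA names and shapes below are made up for readability; the actual IR is built by the rewriter, not written by hand.

// One loop iteration for shape entry i = 0, with input dim 4.
%int0 = torch.constant.int 0
%slice = torch.aten.select.int %shape, %int0, %int0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%dim = torch.aten.item %slice : !torch.vtensor<[1],si64> -> !torch.int
// 1 if the requested dim is 0, otherwise 0.
%is_zero = torch.aten.eq.int %dim, %int0 : !torch.int, !torch.int -> !torch.bool
%is_zero_int = torch.aten.Int.bool %is_zero : !torch.bool -> !torch.int
// Static input size at index 0, added only when the shape entry is 0.
%int4 = torch.constant.int 4
%offset = torch.aten.mul.int %is_zero_int, %int4 : !torch.int, !torch.int -> !torch.int
%final_dim = torch.aten.add.int %dim, %offset : !torch.int, !torch.int -> !torch.int
%dims = torch.prim.ListConstruct %final_dim : (!torch.int) -> !torch.list<int>
%result = torch.aten.reshape %data, %dims : !torch.vtensor<[4],f32>, !torch.list<int> -> !torch.vtensor<[4],f32>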
