Skip to content

Commit 88c438a

Browse files
committed
[OnnxToTorch][GridSample] Add support for border padding mode
1 parent 63247ae commit 88c438a

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
140140
}
141141

142142
std::string padding;
143+
int64_t paddingModeInt;
143144
if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
144145
return rewriter.notifyMatchFailure(binder.op,
145146
"padding_mode bind failure");
146-
if (padding != "zeros")
147+
if (padding == "zeros") {
148+
paddingModeInt = 0;
149+
} else if (padding == "border") {
150+
paddingModeInt = 1;
151+
} else {
147152
return rewriter.notifyMatchFailure(
148-
binder.op, "currently only padding_mode : zeros supported");
153+
binder.op,
154+
"currently only padding_mode : zeros and border supported");
155+
}
149156
int64_t align;
150157
if (binder.s64IntegerAttr(align, "align_corners", 0))
151158
return rewriter.notifyMatchFailure(binder.op,
@@ -157,7 +164,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
157164

158165
Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
159166
binder.getLoc(), rewriter.getType<Torch::IntType>(),
160-
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
167+
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
168+
paddingModeInt));
161169

162170
bool alignMode = align;
163171
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(

0 commit comments

Comments
 (0)