@@ -140,12 +140,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
140
140
}
141
141
142
142
std::string padding;
143
+ int64_t paddingModeInt;
143
144
if (binder.customOpNameStringAttr (padding, " padding_mode" , " zeros" ))
144
145
return rewriter.notifyMatchFailure (binder.op ,
145
146
" padding_mode bind failure" );
146
- if (padding != " zeros" )
147
+ if (padding == " zeros" ) {
148
+ paddingModeInt = 0 ;
149
+ } else if (padding == " border" ) {
150
+ paddingModeInt = 1 ;
151
+ } else {
147
152
return rewriter.notifyMatchFailure (
148
- binder.op , " currently only padding_mode : zeros supported" );
153
+ binder.op ,
154
+ " currently only padding_mode : zeros and border supported" );
155
+ }
149
156
int64_t align;
150
157
if (binder.s64IntegerAttr (align, " align_corners" , 0 ))
151
158
return rewriter.notifyMatchFailure (binder.op ,
@@ -157,7 +164,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
157
164
158
165
Value paddingMode = rewriter.create <Torch::ConstantIntOp>(
159
166
binder.getLoc (), rewriter.getType <Torch::IntType>(),
160
- rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), 0 ));
167
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ),
168
+ paddingModeInt));
161
169
162
170
bool alignMode = align;
163
171
Value alignCorners = rewriter.create <Torch::ConstantBoolOp>(
0 commit comments