Skip to content

Commit 9ab2a15

Browse files
authored
[Torch] emit upsample_bilinear2d(.vec) ops (llvm#3834)
1 parent 2b01f8b commit 9ab2a15

File tree

5 files changed

+109
-0
lines changed

5 files changed

+109
-0
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14093,6 +14093,59 @@ def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [
1409314093
}];
1409414094
}
1409514095

14096+
def Torch_AtenUpsampleBilinear2dOp : Torch_Op<"aten.upsample_bilinear2d", [
14097+
AllowsTypeRefinement,
14098+
HasValueSemantics,
14099+
ReadOnly
14100+
]> {
14101+
let summary = "Generated op for `aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)`";
14102+
let arguments = (ins
14103+
AnyTorchTensorType:$self,
14104+
AnyTorchListOfTorchIntType:$output_size,
14105+
Torch_BoolType:$align_corners,
14106+
AnyTorchOptionalFloatType:$scales_h,
14107+
AnyTorchOptionalFloatType:$scales_w
14108+
);
14109+
let results = (outs
14110+
AnyTorchOptionalTensorType:$result
14111+
);
14112+
let hasCustomAssemblyFormat = 1;
14113+
let extraClassDefinition = [{
14114+
ParseResult AtenUpsampleBilinear2dOp::parse(OpAsmParser &parser, OperationState &result) {
14115+
return parseDefaultTorchOp(parser, result, 5, 1);
14116+
}
14117+
void AtenUpsampleBilinear2dOp::print(OpAsmPrinter &printer) {
14118+
printDefaultTorchOp(printer, *this, 5, 1);
14119+
}
14120+
}];
14121+
}
14122+
14123+
def Torch_AtenUpsampleBilinear2dVecOp : Torch_Op<"aten.upsample_bilinear2d.vec", [
14124+
AllowsTypeRefinement,
14125+
HasValueSemantics,
14126+
ReadOnly
14127+
]> {
14128+
let summary = "Generated op for `aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)`";
14129+
let arguments = (ins
14130+
AnyTorchTensorType:$input,
14131+
AnyTorchOptionalListOfTorchIntType:$output_size,
14132+
Torch_BoolType:$align_corners,
14133+
AnyTorchOptionalListOfTorchFloatType:$scale_factors
14134+
);
14135+
let results = (outs
14136+
AnyTorchOptionalTensorType:$result
14137+
);
14138+
let hasCustomAssemblyFormat = 1;
14139+
let extraClassDefinition = [{
14140+
ParseResult AtenUpsampleBilinear2dVecOp::parse(OpAsmParser &parser, OperationState &result) {
14141+
return parseDefaultTorchOp(parser, result, 4, 1);
14142+
}
14143+
void AtenUpsampleBilinear2dVecOp::print(OpAsmPrinter &printer) {
14144+
printDefaultTorchOp(printer, *this, 4, 1);
14145+
}
14146+
}];
14147+
}
14148+
1409614149
def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [
1409714150
AllowsTypeRefinement,
1409814151
HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11043,6 +11043,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1104311043
" }\n"
1104411044
" return %10 : !torch.list<int>\n"
1104511045
" }\n"
11046+
" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
11047+
" %int0 = torch.constant.int 0\n"
11048+
" %int1 = torch.constant.int 1\n"
11049+
" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
11050+
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
11051+
" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
11052+
" %3 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
11053+
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
11054+
" return %4 : !torch.list<int>\n"
11055+
" }\n"
11056+
" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<list<float>>) -> !torch.list<int> {\n"
11057+
" %0 = call @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0, %arg1, %arg3) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<float>>) -> !torch.list<int>\n"
11058+
" return %0 : !torch.list<int>\n"
11059+
" }\n"
1104611060
" func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
1104711061
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1104811062
" return %0#1 : !torch.int\n"
@@ -12576,6 +12590,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1257612590
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1257712591
" return %0#1 : !torch.int\n"
1257812592
" }\n"
12593+
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.int {\n"
12594+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12595+
" return %0#1 : !torch.int\n"
12596+
" }\n"
12597+
" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<list<float>>) -> !torch.int {\n"
12598+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12599+
" return %0#1 : !torch.int\n"
12600+
" }\n"
1257912601
" func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
1258012602
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1258112603
" return %0#1 : !torch.int\n"

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,9 @@
531531
"_SoftmaxModule_basic",
532532
"UpSampleNearest2dDynamicFactor_basic",
533533
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
534+
# torch export: RuntimeError: cannot mutate tensors with frozen storage
535+
"ElementwiseRreluWithNoiseTrainModule_basic",
536+
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
534537
}
535538

536539
FX_IMPORTER_STABLEHLO_XFAIL_SET = {
@@ -979,6 +982,9 @@
979982
# materialization callback produced value of incorrect type failed
980983
"ReduceMaxAlongDimUnsignedInt_basic",
981984
"ReduceMinAlongDimUnsignedInt_basic",
985+
# torch export: RuntimeError: cannot mutate tensors with frozen storage
986+
"ElementwiseRreluWithNoiseTrainModule_basic",
987+
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
982988
}
983989

984990
STABLEHLO_PASS_SET = {

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2342,6 +2342,20 @@ def aten〇upsample_nearest2d〇vec〡shape(input: List[int], output_size: Optio
23422342
assert scale_factors is not None
23432343
return [input[0], input[1], int(input[2] * scale_factors[0]), int(input[3] * scale_factors[1])]
23442344

2345+
@check_shape_function([
2346+
Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True)
2347+
])
2348+
def aten〇upsample_bilinear2d〡shape(self: List[int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
2349+
return [self[0], self[1], output_size[0], output_size[1]]
2350+
2351+
@check_shape_function([
2352+
Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True, None),
2353+
Invocation(TensorOfShape(1, 3, 10, 9), None, True, [2.0, 2.3]),
2354+
Invocation(TensorOfShape(1, 3, 5, 6), None, True, [2.5, 1.0])
2355+
])
2356+
def aten〇upsample_bilinear2d〇vec〡shape(input: List[int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> List[int]:
2357+
return aten〇upsample_nearest2d〇vec〡shape(input, output_size, scale_factors)
2358+
23452359
# ==============================================================================
23462360
# Dtype Functions
23472361
# ==============================================================================
@@ -3570,6 +3584,16 @@ def aten〇upsample_nearest2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], o
35703584
self_rank, self_dtype = input_rank_dtype
35713585
return self_dtype
35723586

3587+
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True))
3588+
def aten〇upsample_bilinear2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int:
3589+
self_rank, self_dtype = self_rank_dtype
3590+
return self_dtype
3591+
3592+
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True, scale_factors=None))
3593+
def aten〇upsample_bilinear2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> int:
3594+
self_rank, self_dtype = input_rank_dtype
3595+
return self_dtype
3596+
35733597
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]))
35743598
def aten〇view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int:
35753599
self_rank, self_dtype = self_rank_dtype

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,10 @@ def emit_with_mutating_variants(key, **kwargs):
10131013
emit("aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
10141014
emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
10151015
emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
1016+
emit(
1017+
"aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)"
1018+
)
1019+
emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)")
10161020
emit(
10171021
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)"
10181022
)

0 commit comments

Comments
 (0)