From 2b2201e29d0bde3d49990760181ab6eafc6a9d5a Mon Sep 17 00:00:00 2001
From: AaronStGeorge
Date: Wed, 19 Feb 2025 19:50:40 -0800
Subject: [PATCH 1/2] Re-enable torch-adjust-calling-conventions tests

---
 .../Transforms/AdjustCallingConventions.cpp   |  37 ++++-
 .../Torch/adjust-calling-conventions.mlir     | 131 +++++++++---------
 2 files changed, 101 insertions(+), 67 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
index e8b0d6b0364c..2e046c7559da 100644
--- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
+++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
@@ -164,11 +164,44 @@ class AdjustCallingConventionForReturn
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // The dialect conversion framework inserts unrealized conversion casts to
+    // materialize legal types from illegal types. For example, for input IR
+    // like
+    //   %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
+    //       !torch.tensor -> !torch.tuple<tensor, tensor>
+    //   return %1 : !torch.tuple<tensor, tensor>
+    // at this stage in the conversion process we'll have something like
+    //   %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
+    //       !torch.tensor -> !torch.tuple<tensor, tensor>
+    //   %2 = builtin.unrealized_conversion_cast %1 :
+    //       !torch.tuple<tensor, tensor> to !torch.tensor
+    //   %3 = builtin.unrealized_conversion_cast %1 :
+    //       !torch.tuple<tensor, tensor> to !torch.tensor
+    //   return %2, %3 : !torch.tensor, !torch.tensor
+    //
+    // Here we map back to the original torch.prim.TupleConstructs.
+    SmallVector<Value> flatOperands;
+    for (const auto &vals : adaptor.getOperands()) {
+      for (const auto &operand : vals) {
+        // A block argument won't have a defining op.
+        if (operand.getDefiningOp() &&
+            isa<UnrealizedConversionCastOp>(operand.getDefiningOp())) {
+          auto definingOp = operand.getDefiningOp()->getOperand(0);
+          // De-duplicate.
+          if (std::find(flatOperands.begin(), flatOperands.end(),
+                        definingOp) == flatOperands.end()) {
+            flatOperands.push_back(definingOp);
+          }
+        } else {
+          flatOperands.push_back(operand);
+        }
+      }
+    }
     SmallVector<Value> newOperands;
-    for (auto operand : adaptor.getOperands()) {
+    for (auto operand : flatOperands) {
       if (!operand)
         continue;
       if (isa<Torch::NoneType>(operand.getType()))
         continue;
diff --git a/test/Dialect/Torch/adjust-calling-conventions.mlir b/test/Dialect/Torch/adjust-calling-conventions.mlir
index 455a8e847486..992b60271327 100644
--- a/test/Dialect/Torch/adjust-calling-conventions.mlir
+++ b/test/Dialect/Torch/adjust-calling-conventions.mlir
@@ -9,6 +9,8 @@ func.func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
   return %arg0 : !torch.tensor
 }
 
+// -----
+
 // CHECK-LABEL: func.func @no_type_bound(
 // CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
 // CHECK: return %[[ARG]] : !torch.tensor
@@ -16,6 +18,8 @@ func.func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
   return %arg0 : !torch.tensor
 }
 
+// -----
+
 // CHECK-LABEL: func.func @call(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
 // CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
@@ -29,71 +33,68 @@ func.func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
   return %arg0 : !torch.tensor
 }
 
-// COM: func.func @none_return() {
-// COM: %[[NONE:.*]] = torch.constant.none
-// COM: return
-// func.func @none_return() -> !torch.none {
-//   %1 = torch.constant.none
-//   return %1 : !torch.none
-// }
+// -----
+
+// CHECK-LABEL: func.func @none_return() {
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: return
+func.func @none_return() -> !torch.none {
+  %1 = torch.constant.none
+  return %1 : !torch.none
+}
 
-// COM: func.func @none_call_return() {
-// COM: call @none_return() : () -> ()
-// COM: %[[NONE:.*]] = torch.constant.none
-// COM: "test.use"(%[[NONE]]) : (!torch.none) -> ()
-// COM: return
-// func.func @none_call_return() {
-//   %0 = call @none_return() : () -> !torch.none
-//   "test.use"(%0) : (!torch.none) -> ()
-//   return
-// }
+// CHECK-LABEL: func.func @none_call_return() {
+// CHECK: call @none_return() : () -> ()
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: "test.use"(%[[NONE]]) : (!torch.none) -> ()
+// CHECK: return
+func.func @none_call_return() {
+  %0 = call @none_return() : () -> !torch.none
+  "test.use"(%0) : (!torch.none) -> ()
+  return
+}
 
-// COM: func.func @tuple_return(
-// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
-// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
-// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
-// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
-// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
-// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
-// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] :
-// COM:     !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
-// COM: %[[CST0:.*]] = torch.constant.int 0
-// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
-// COM:     !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
-// COM: %[[CST1:.*]] = torch.constant.int 1
-// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
-// COM:     !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
-// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
-// func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
-//                         %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
-//   %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
-//   return %1 : !torch.tuple<tensor, tensor>
-// }
+// -----
+
+// CHECK-LABEL: func.func @tuple_return(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
+// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
+// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
+// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
+// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
+func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
+                        %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
+  %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
+  return %1 : !torch.tuple<tensor, tensor>
+}
 
-// COM: func.func @call_tuple_return(
-// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
-// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
-// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
-// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
-// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
-// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
-// COM: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
-// COM: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
-// COM: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
-// COM: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
-// COM: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) :
-// COM:     (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
-// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 :
-// COM:     !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
-// COM: %[[CST0:.*]] = torch.constant.int 0
-// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
-// COM:     !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
-// COM: %[[CST1:.*]] = torch.constant.int 1
-// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
-// COM:     !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
-// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
-// func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
-//                              %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
-//   %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
-//   return %0 : !torch.tuple<tensor, tensor>
-// }
+// CHECK-LABEL: func.func @call_tuple_return(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
+// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
+// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
+// CHECK: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
+// CHECK: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
+// CHECK: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
+// CHECK: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
+// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
+// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
+// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
+func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
+                             %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
+  %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
+  return %0 : !torch.tuple<tensor, tensor>
+}
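A minimal before/after sketch of the return adjustment exercised by the re-enabled @tuple_return test in the first patch. The function name @f and the SSA names are illustrative only, not taken from the test or from the pass's actual output; the op sequence mirrors the test's CHECK lines:

  // Before: the function returns a tuple, which the pass expands in the
  // calling convention.
  func.func @f(%a: !torch.tensor, %b: !torch.tensor) -> !torch.tuple<tensor, tensor> {
    %t = torch.prim.TupleConstruct %a, %b : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
    return %t : !torch.tuple<tensor, tensor>
  }

  // After: the tuple is expanded element-wise; each element is re-extracted
  // with torch.prim.TupleIndex and returned as a separate value.
  func.func @f(%a: !torch.tensor, %b: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
    %t = torch.prim.TupleConstruct %a, %b : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
    %c0 = torch.constant.int 0
    %r0 = torch.prim.TupleIndex %t, %c0 : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
    %c1 = torch.constant.int 1
    %r1 = torch.prim.TupleIndex %t, %c1 : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
    return %r0, %r1 : !torch.tensor, !torch.tensor
  }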
From eb5c7d6d21c2e450229f3c70752991fd9d08463b Mon Sep 17 00:00:00 2001
From: AaronStGeorge
Date: Thu, 27 Feb 2025 15:10:39 -0800
Subject: [PATCH 2/2] Cleanup

---
 .../Transforms/AdjustCallingConventions.cpp   | 83 +++++++++----------
 1 file changed, 37 insertions(+), 46 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
index 2e046c7559da..3508f1bc059e 100644
--- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
+++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
@@ -166,58 +166,49 @@ class AdjustCallingConventionForReturn
   LogicalResult
   matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // The dialect conversion framework inserts unrealized conversion casts to
-    // materialize legal types from illegal types. For example, for input IR
-    // like
-    //   %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
-    //       !torch.tensor -> !torch.tuple<tensor, tensor>
-    //   return %1 : !torch.tuple<tensor, tensor>
-    // at this stage in the conversion process we'll have something like
-    //   %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
-    //       !torch.tensor -> !torch.tuple<tensor, tensor>
-    //   %2 = builtin.unrealized_conversion_cast %1 :
-    //       !torch.tuple<tensor, tensor> to !torch.tensor
-    //   %3 = builtin.unrealized_conversion_cast %1 :
-    //       !torch.tuple<tensor, tensor> to !torch.tensor
-    //   return %2, %3 : !torch.tensor, !torch.tensor
-    //
-    // Here we map back to the original torch.prim.TupleConstructs.
-    SmallVector<Value> flatOperands;
+    SmallVector<Value> newOperands;
     for (const auto &vals : adaptor.getOperands()) {
-      for (const auto &operand : vals) {
-        // A block argument won't have a defining op.
-        if (operand.getDefiningOp() &&
-            isa<UnrealizedConversionCastOp>(operand.getDefiningOp())) {
-          auto definingOp = operand.getDefiningOp()->getOperand(0);
-          // De-duplicate.
-          if (std::find(flatOperands.begin(), flatOperands.end(),
-                        definingOp) == flatOperands.end()) {
-            flatOperands.push_back(definingOp);
+      if (vals.size() == 1) {
+        if (isa<Torch::NoneType>(vals[0].getType()))
+          continue;
+        newOperands.push_back(vals[0]);
+      } else if (vals.size() > 1) {
+        // The dialect conversion framework inserts unrealized conversion casts
+        // to materialize legal types from illegal types. For example, for
+        // input IR like
+        //   %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
+        //       !torch.tensor -> !torch.tuple<tensor, tensor>
+        //   return %1 : !torch.tuple<tensor, tensor>
+        // at this stage in the conversion process we'll have something like
+        //   %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
+        //       !torch.tensor -> !torch.tuple<tensor, tensor>
+        //   %2 = builtin.unrealized_conversion_cast %1 :
+        //       !torch.tuple<tensor, tensor> to !torch.tensor
+        //   %3 = builtin.unrealized_conversion_cast %1 :
+        //       !torch.tuple<tensor, tensor> to !torch.tensor
+        //   return %2, %3 : !torch.tensor, !torch.tensor
+        //
+        // Given (%2, %3) as operands, here we map back to the original
+        // torch.prim.TupleConstruct.
+        if (vals[0].getDefiningOp() &&
+            isa<UnrealizedConversionCastOp>(vals[0].getDefiningOp())) {
+          Value operand = vals[0].getDefiningOp()->getOperand(0);
+          if (auto tuple = dyn_cast<Torch::TupleType>(operand.getType())) {
+            Location loc = op.getLoc();
+            for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
+              auto i = rewriter.create<ConstantIntOp>(
+                  loc, rewriter.getI64IntegerAttr(en.index()));
+              newOperands.push_back(rewriter.create<PrimTupleIndexOp>(
+                  loc, en.value(), operand, i));
+            }
+            continue;
           }
-        } else {
-          flatOperands.push_back(operand);
         }
-      }
-    }
-    SmallVector<Value> newOperands;
-    for (auto operand : flatOperands) {
-      if (!operand)
-        continue;
-      if (isa<Torch::NoneType>(operand.getType()))
-        continue;
-      if (auto tuple = dyn_cast<Torch::TupleType>(operand.getType())) {
-        Location loc = op.getLoc();
-        for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
-          auto i = rewriter.create<ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(en.index()));
-          newOperands.push_back(
-              rewriter.create<PrimTupleIndexOp>(loc, en.value(), operand, i));
-        }
-        continue;
+        llvm::append_range(newOperands, vals);
       }
-      newOperands.push_back(operand);
     }
+
     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, newOperands);
     return success();
   }
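For the single-value path in the cleanup (vals.size() == 1), a small sketch of the intended end-to-end behavior, mirroring the @none_return test re-enabled in the first patch: a !torch.none operand maps to exactly one adaptor value and is simply dropped from the rewritten return, and (per the test's CHECK lines) the function signature loses the result as well:

  // Before:
  func.func @none_return() -> !torch.none {
    %0 = torch.constant.none
    return %0 : !torch.none
  }

  // After: the !torch.none return value disappears entirely.
  func.func @none_return() {
    %0 = torch.constant.none
    return
  }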