
Commit 6daa20e

Use op.dtype to create aten.empty.memory_format during decomposition. (#3941)
Prior to the change in this PR, `torch-mlir-opt --convert-torch-to-linalg` ran into the following error:

```
error: 'tensor.cast' op operand type 'tensor<200x200x26xf32>' and result type 'tensor<200x200x26xf64>' are cast incompatible
  %1 = torch.aten.empty.memory_format %0, %none, %none, %none, %false, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
  ^
note: see current operation: %12 = "tensor.cast"(%11) : (tensor<200x200x26xf32>) -> tensor<200x200x26xf64>
```

This happened because when the `dtype` of `aten.empty.memory_format` is `none`, `f32` was selected by default as the element type of the resulting tensor, which does not match the actual element type of the result.
1 parent ae310b4 commit 6daa20e
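The core of the fix is a small helper that, when the op's `dtype` operand is `none`, derives the dtype constant from the op's result tensor type instead of defaulting to `f32`. A condensed sketch of that logic follows; `pickDtype` is a hypothetical name used here for illustration, while the commit's actual helper is `Torch::getDtypeFromOp` in `lib/Dialect/Torch/Utils/Utils.cpp` below.

```cpp
#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

template <typename OpTy>
static FailureOr<Value> pickDtype(PatternRewriter &rewriter, OpTy op) {
  Value dtype = op.getDtype();
  if (!isa<Torch::NoneType>(dtype.getType()))
    return dtype; // An explicit dtype operand was given; keep it.
  // No dtype operand: read the dtype off the op's *result* type, which type
  // refinement has already set (f64 in the failing example above), instead of
  // silently defaulting to f32.
  auto resultType = cast<BaseTensorType>(op.getType());
  if (!resultType.hasDtype())
    return rewriter.notifyMatchFailure(op, "result type has no dtype");
  // Materialize the torch dtype integer (e.g. 7 for float64) from the result.
  return getDtypeIntValueForType(rewriter, op.getLoc(), resultType.getDtype());
}
```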

File tree

5 files changed: +139 −20 lines changed


include/torch-mlir/Dialect/Torch/Utils/Utils.h

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,9 @@ Type getTypeForTorchType(
     MLIRContext *context, Type type,
     mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
 
+template <typename OpTy>
+FailureOr<Value> getDtypeFromOp(PatternRewriter &rewriter, OpTy op);
+
 FailureOr<Type> getTorchTypeForScalarType(MLIRContext *context,
                                           torch_upstream::ScalarType dtypeInt);

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 32 additions & 20 deletions
@@ -7087,9 +7087,16 @@ class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
         Torch::ListType::get(Torch::IntType::get(op.getContext()));
     Value sizeList =
         rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
+
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
+    }
+
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
+        op, op.getType(), sizeList, *dtype, op.getLayout(), op.getDevice(),
+        op.getPinMemory(), op.getMemoryFormat());
     return success();
   }
 };

@@ -7838,18 +7845,13 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
   LogicalResult matchAndRewrite(AtenNewEmptyOp op,
                                 PatternRewriter &rewriter) const override {
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
-    Value dtype = op.getDtype();
-    if (isa<Torch::NoneType>(dtype.getType())) {
-      BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
-      if (!tensorType.hasDtype()) {
-        return rewriter.notifyMatchFailure(
-            op, "expected input tensor to have a dtype");
-      }
-      dtype =
-          getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
     }
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(),
+        op, op.getType(), op.getSize(), *dtype, op.getLayout(), op.getDevice(),
         op.getPinMemory(), /*memoryFormat=*/noneVal);
     return success();
   }

@@ -9257,12 +9259,12 @@ class DecomposeAtenRandnGeneratorOp
     Location loc = op.getLoc();
     auto resultType = cast<BaseTensorType>(op.getType());
 
-    if (!resultType.hasDtype()) {
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
       return rewriter.notifyMatchFailure(
-          op, "expected result type to have a dtype");
+          op, "could not determine dtype from the op.");
     }
 
-    Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
     Value none = rewriter.create<ConstantNoneOp>(loc);
     Value low = rewriter.create<Torch::ConstantFloatOp>(
         loc, rewriter.getF64FloatAttr((double)0.0));

@@ -9274,12 +9276,12 @@ class DecomposeAtenRandnGeneratorOp
         loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));
 
     Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/dtype,
+        loc, resultType, op.getSize(), /*dtype=*/*dtype,
         /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/none);
     Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/dtype,
+        loc, resultType, op.getSize(), /*dtype=*/*dtype,
         /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/none);

@@ -9377,8 +9379,13 @@ class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
         loc, rewriter.getF64FloatAttr((double)0.0));
     Value high = rewriter.create<Torch::ConstantFloatOp>(
         loc, rewriter.getF64FloatAttr((double)1.0));
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
+    }
     Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/op.getDtype(),
+        loc, resultType, op.getSize(), /*dtype=*/*dtype,
         /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/noneVal);

@@ -9536,9 +9543,14 @@ class DecomposeAtenEmptyStridedOp
 
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
 
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
+    }
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
+        op, op.getType(), op.getSize(), *dtype, op.getLayout(), op.getDevice(),
+        op.getPinMemory(), /*memoryFormat=*/noneVal);
     return success();
   }
 };
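For readers less familiar with the MLIR error-handling idiom used above: `FailureOr<T>` either carries a value or records failure, `failed()` tests for the failure case, and `*dtype` dereferences it to get the contained `Value`. A tiny standalone sketch of the idiom, assuming an MLIR build is available (the `parsePositive` helper is hypothetical and not part of this patch):

```cpp
#include "mlir/Support/LogicalResult.h"

using mlir::failed;
using mlir::failure;
using mlir::FailureOr;

// Either returns the parsed value or signals failure to the caller.
static FailureOr<int> parsePositive(int x) {
  if (x <= 0)
    return failure();
  return x;
}

int demo() {
  FailureOr<int> r = parsePositive(5);
  if (failed(r))
    return -1; // bail out, mirroring notifyMatchFailure in the patterns above
  return *r;   // dereference to use the contained value, as with *dtype
}
```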

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 36 additions & 0 deletions
@@ -237,6 +237,42 @@ Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                                      rewriter.getI64IntegerAttr(intType));
 }
 
+template <typename OpTy>
+FailureOr<Value> Torch::getDtypeFromOp(PatternRewriter &rewriter, OpTy op) {
+  // For ops like AtenEmptyLikeOp, if the dtype specified in the op is none,
+  // it defaults to the dtype of the input. Since dtype specifies the dtype of
+  // the output, in this scenario we can look at the dtype of the output
+  // instead of the input itself.
+  // For ops like AtenRandOp, if the dtype specified in the op is none, it
+  // defaults to a global value. In this case as well we can look at the dtype
+  // of the output, as it will already be set according to the default global
+  // value.
+  Value dtype = op.getDtype();
+  if (isa<Torch::NoneType>(dtype.getType())) {
+    BaseTensorType tensorType = cast<BaseTensorType>(op.getType());
+    if (!tensorType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input tensor to have a dtype");
+    }
+    dtype =
+        getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
+  }
+  return dtype;
+}
+// Explicit template instantiations.
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenEmptyLikeOp>(PatternRewriter &rewriter,
+                                       AtenEmptyLikeOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenNewEmptyOp>(PatternRewriter &rewriter,
+                                      AtenNewEmptyOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenRandOp>(PatternRewriter &rewriter, AtenRandOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenEmptyStridedOp>(PatternRewriter &rewriter,
+                                          AtenEmptyStridedOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenRandnGeneratorOp>(PatternRewriter &rewriter,
+                                            AtenRandnGeneratorOp op);
+
 // Helper to convert a tensor to a specific scalar type.
 Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
                                   Value input, Type dtype) {
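A note on the trailing `template FailureOr<Value> Torch::getDtypeFromOp<...>` lines: because the template is defined in this .cpp file rather than in the header, explicit instantiations are needed so the decomposition patterns compiled in other translation units can link against it. A generic, self-contained illustration of that pattern (the `scaled` function is hypothetical, not from this repo):

```cpp
#include <cstdio>

// In a real project the declaration would sit in a header and the definition
// plus the explicit instantiations in one .cpp file; callers that only see
// the header still link, but only for the instantiated types.
template <typename T> T scaled(T value);          // header-style declaration

template <typename T> T scaled(T value) {         // out-of-line definition
  return value * T(2);
}
template int scaled<int>(int);                    // explicit instantiations
template double scaled<double>(double);

int main() {
  std::printf("%d %f\n", scaled(21), scaled(1.5));
  return 0;
}
```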

projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py

Lines changed: 20 additions & 0 deletions
@@ -641,6 +641,26 @@ def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
 
 
+class EmptyLikeDefaultDtypeFloat64InputModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float64, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.empty_like(x).fill_(0)
+
+
+@register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeFloat64InputModule())
+def EmptyLikeDefaultDtypeFloat64InputModule_basic(module, tu: TestUtils):
+    module.forward(torch.ones((200, 200, 26), dtype=torch.float64))
+
+
 # ==============================================================================

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 48 additions & 0 deletions
@@ -312,3 +312,51 @@ func.func @convolution_backward_none_result(%arg0: !torch.vtensor<[1,1,3,3],f32>
   %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
   return %result1, %result2 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
 }
+
+// -----
+// CHECK-LABEL:   func.func @emptyLikeNoneDtype(
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+// CHECK:           %[[DTYPE:.*]] = torch.constant.int 7
+// CHECK:           %[[NONE:.*]] = torch.constant.none
+// CHECK:           %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:           %[[C200:.*]] = torch.constant.int 200
+// CHECK:           %[[C26:.*]] = torch.constant.int 26
+// CHECK:           %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
+func.func @emptyLikeNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+  %none = torch.constant.none
+  %none_0 = torch.constant.none
+  %none_1 = torch.constant.none
+  %false = torch.constant.bool false
+  %none_2 = torch.constant.none
+  %0 = torch.aten.empty_like %arg0, %none, %none_0, %none_1, %false, %none_2 : !torch.vtensor<[200,200,26],f64>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
+  return %0 : !torch.vtensor<[200,200,26],f64>
+}
+
+// -----
+// CHECK-LABEL:   func.func @randNoneDtype(
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+// CHECK:           %[[DTYPE:.*]] = torch.constant.int 7
+// CHECK:           %[[C1:.*]] = torch.constant.float 1.000000e+00
+// CHECK:           %[[C0:.*]] = torch.constant.float 0.000000e+00
+// CHECK:           %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:           %[[NONE:.*]] = torch.constant.none
+// CHECK:           %[[C200:.*]] = torch.constant.int 200
+// CHECK:           %[[C26:.*]] = torch.constant.int 26
+// CHECK:           %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[CPU:.*]] = torch.constant.device "cpu"
+// CHECK:           %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[CPU]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
+// CHECK:           %[[UNIFORM:.*]] = torch.aten.uniform %[[MEM_FMT]], %[[C0]], %[[C1]], %[[NONE]] : !torch.vtensor<[200,200,26],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[200,200,26],f64>
+// CHECK:           return %[[UNIFORM]] : !torch.vtensor<[200,200,26],f64>
+func.func @randNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+  %int200 = torch.constant.int 200
+  %int200_0 = torch.constant.int 200
+  %int26 = torch.constant.int 26
+  %0 = torch.prim.ListConstruct %int200, %int200_0, %int26 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %none = torch.constant.none
+  %none_1 = torch.constant.none
+  %cpu = torch.constant.device "cpu"
+  %false = torch.constant.bool false
+  %1 = torch.aten.rand %0, %none, %none_1, %cpu, %false : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[200,200,26],f64>
+  return %1 : !torch.vtensor<[200,200,26],f64>
+}
