
Commit f94b072: [TOSA] Handle float<->bool cast via i8 in tosaCastTensorToType (#4257)
Parent: 46925eb

3 files changed, +83 -0 lines

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 24 additions & 0 deletions
@@ -434,6 +434,30 @@ std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
   // if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
   //   return std::nullopt;
 
+  if (llvm::isa<FloatType>(srcElemTy) && destElemTy.isInteger(1)) {
+    // TOSA does not support casting from float -> i1. In PyTorch, an element
+    // casts to True iff it is non-zero, so lower this as not(equal(src, 0)).
+    Value zeroValue = *getConstTensor<float>(rewriter, op, 0.0f, {}, srcElemTy);
+    if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue)
+            .failed())
+      return std::nullopt;
+
+    auto cmpTy = srcType.clone(rewriter.getIntegerType(1));
+    Value isEq =
+        rewriter.create<tosa::EqualOp>(op->getLoc(), cmpTy, src, zeroValue);
+    return rewriter.create<tosa::LogicalNotOp>(op->getLoc(),
+                                               srcType.clone(destElemTy), isEq);
+  }
+
+  if (srcElemTy.isInteger(1) && llvm::isa<FloatType>(destElemTy)) {
+    // TOSA does not support casting from i1 -> float.
+    // Instead, cast to i8 first and then cast the i8 to the float type.
+    TensorType midType = srcType.clone(rewriter.getIntegerType(8));
+    Value mid = rewriter.create<tosa::CastOp>(op->getLoc(), midType, src);
+    return rewriter.create<tosa::CastOp>(op->getLoc(),
+                                         srcType.clone(destElemTy), mid);
+  }
+
   if (srcElemTy == destElemTy)
     return src;
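For intuition, here is a small PyTorch snippet (illustrative only, not part of the commit) showing the elementwise semantics the two lowerings must reproduce: float -> bool is a non-zero test, and bool -> float yields 0.0/1.0, which the i1 -> i8 -> float chain preserves.

import torch

# float -> bool: each element maps to True iff it is non-zero,
# matching the not(equal(src, 0)) decomposition above.
x = torch.tensor([[0.0, -1.5], [2.0, 0.0]])
print(x.to(torch.bool))
# tensor([[False,  True], [ True, False]])

# bool -> float: True/False map to 1.0/0.0, matching the i1 -> i8 -> f32 chain.
print(x.to(torch.bool).to(torch.float32))
# tensor([[0., 1.], [1., 0.]])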

projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py

Lines changed: 15 additions & 0 deletions
@@ -27,6 +27,21 @@ def TypeConversionF32ToF64Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
 
 
+class TypeConversionF32ToI1Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.float32, True)])
+    def forward(self, x):
+        return x.to(torch.bool)
+
+
+@register_test_case(module_factory=lambda: TypeConversionF32ToI1Module())
+def TypeConversionF32ToI1Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5))
+
+
 class TypeConversionF64ToF32Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
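The commit exercises only the float -> bool direction end-to-end; a matching i1 -> float test could follow the same pattern. A sketch under the suite's existing conventions (TypeConversionI1ToF32Module is hypothetical, not part of this change):

class TypeConversionI1ToF32Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.bool, True)])
    def forward(self, x):
        return x.to(torch.float32)


@register_test_case(module_factory=lambda: TypeConversionI1ToF32Module())
def TypeConversionI1ToF32Module_basic(module, tu: TestUtils):
    # Feed a boolean tensor derived from random floats.
    module.forward(tu.rand(3, 5) > 0.5)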

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 44 additions & 0 deletions
@@ -1169,6 +1169,50 @@ func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
   return %0 : !torch.vtensor<[3,5],si64>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToBool(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 11
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.none
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor<f32>, !tosa.shape<2>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.equal %[[VAL_1]], %[[VAL_7]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xi1>
+// CHECK: %[[VAL_9:.*]] = tosa.logical_not %[[VAL_8]] : (tensor<3x5xi1>) -> tensor<3x5xi1>
+// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
+// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,5],i1>
+// CHECK: }
+func.func @torch.aten.to.dtype$floatToBool(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> {
+  %int11 = torch.constant.int 11
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1>
+  return %0 : !torch.vtensor<[3,5],i1>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.to.dtype$boolToFloat(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],i1> -> tensor<3x4xi1>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 6
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.none
+// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi1>) -> tensor<3x4xi8>
+// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<3x4xi8>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32>
+// CHECK: }
+func.func @torch.aten.to.dtype$boolToFloat(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> {
+  %int6 = torch.constant.int 6
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[3,4],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.gather(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,
