@@ -1169,6 +1169,50 @@ func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !
1169
1169
return %0 : !torch.vtensor <[3 ,5 ],si64 >
1170
1170
}
1171
1171
1172
+ // -----
1173
+ // CHECK-LABEL: func.func @torch.aten.to.dtype$floatToBool(
1174
+ // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> {
1175
+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
1176
+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 11
1177
+ // CHECK: %[[VAL_3:.*]] = torch.constant.bool false
1178
+ // CHECK: %[[VAL_4:.*]] = torch.constant.none
1179
+ // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
1180
+ // CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
1181
+ // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor<f32>, !tosa.shape<2>) -> tensor<1x1xf32>
1182
+ // CHECK: %[[VAL_8:.*]] = tosa.equal %[[VAL_1]], %[[VAL_7]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xi1>
1183
+ // CHECK: %[[VAL_9:.*]] = tosa.logical_not %[[VAL_8]] : (tensor<3x5xi1>) -> tensor<3x5xi1>
1184
+ // CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
1185
+ // CHECK: return %[[VAL_10]] : !torch.vtensor<[3,5],i1>
1186
+ // CHECK: }
1187
+ func.func @torch.aten.to.dtype$floatToBool (%arg0: !torch.vtensor <[3 ,5 ],f32 >) -> !torch.vtensor <[3 ,5 ],i1 > {
1188
+ %int11 = torch.constant.int 11
1189
+ %false = torch.constant.bool false
1190
+ %none = torch.constant.none
1191
+ %0 = torch.aten.to.dtype %arg0 , %int11 , %false , %false , %none : !torch.vtensor <[3 ,5 ],f32 >, !torch.int , !torch.bool , !torch.bool , !torch.none -> !torch.vtensor <[3 ,5 ],i1 >
1192
+ return %0 : !torch.vtensor <[3 ,5 ],i1 >
1193
+ }
1194
+
1195
+ // -----
1196
+ // CHECK-LABEL: func.func @torch.aten.to.dtype$boolToFloat(
1197
+ // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> {
1198
+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],i1> -> tensor<3x4xi1>
1199
+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 6
1200
+ // CHECK: %[[VAL_3:.*]] = torch.constant.bool false
1201
+ // CHECK: %[[VAL_4:.*]] = torch.constant.none
1202
+ // CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi1>) -> tensor<3x4xi8>
1203
+ // CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<3x4xi8>) -> tensor<3x4xf32>
1204
+ // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
1205
+ // CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32>
1206
+ // CHECK: }
1207
+ func.func @torch.aten.to.dtype$boolToFloat (%arg0: !torch.vtensor <[3 ,4 ],i1 >) -> !torch.vtensor <[3 ,4 ],f32 > {
1208
+ %int6 = torch.constant.int 6
1209
+ %false = torch.constant.bool false
1210
+ %none = torch.constant.none
1211
+ %0 = torch.aten.to.dtype %arg0 , %int6 , %false , %false , %none : !torch.vtensor <[3 ,4 ],i1 >, !torch.int , !torch.bool , !torch.bool , !torch.none -> !torch.vtensor <[3 ,4 ],f32 >
1212
+ return %0 : !torch.vtensor <[3 ,4 ],f32 >
1213
+ }
1214
+
1215
+
1172
1216
// -----
1173
1217
// CHECK-LABEL: func.func @torch.aten.gather(
1174
1218
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,
0 commit comments