diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index bf2a605c950b..8900b7c12d64 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4888,11 +4888,11 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
     return nullptr;
   }
   auto elementType = shapedty.getElementType();
-  if (isa<Torch::IntType>(elementType)) {
+  if (isa<IntegerType>(elementType)) {
     Attribute attribute = IntegerAttr::get(elementType, 1);
     return DenseElementsAttr::get(shapedty, attribute);
   }
-  if (isa<Torch::FloatType>(elementType)) {
+  if (isa<mlir::FloatType>(elementType)) {
     Attribute attribute = FloatAttr::get(elementType, 1.0);
     return DenseElementsAttr::get(shapedty, attribute);
   }
@@ -4932,7 +4932,7 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
     Attribute attribute = IntegerAttr::get(elementType, 0);
     return DenseElementsAttr::get(shapedty, attribute);
   }
-  if (isa<Torch::FloatType>(elementType)) {
+  if (isa<mlir::FloatType>(elementType)) {
     Attribute attribute = FloatAttr::get(elementType, 0.0);
     return DenseElementsAttr::get(shapedty, attribute);
   }
@@ -4972,7 +4972,7 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
       return DenseElementsAttr::get(shapedty, attribute);
     }
   }
-  if (isa<Torch::FloatType>(elementType)) {
+  if (isa<mlir::FloatType>(elementType)) {
     double value = 0.0;
     if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
       Attribute attribute = FloatAttr::get(elementType, value);
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index cb1a69e6a622..d100fe9dcfde 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -857,10 +857,9 @@ func.func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vt
 // CHECK: %[[VAL_1:.*]] = torch.constant.int 3
 // CHECK: %[[VAL_2:.*]] = torch.constant.none
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32>
-// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xi32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
 // CHECK: }
 func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
   %int4 = torch.constant.int 4
@@ -925,10 +924,9 @@ func.func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !to
 // CHECK: %[[VAL_1:.*]] = torch.constant.int 3
 // CHECK: %[[VAL_2:.*]] = torch.constant.none
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<3x4xi32>}> : () -> tensor<3x4xi32>
-// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xi32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
 // CHECK: }
 func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
   %int4 = torch.constant.int 4
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 34fd1b886e26..aebbaff6f2e5 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -3410,3 +3410,101 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1
   torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
   return %3 : !torch.vtensor<[?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.ones$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x3x4xf32>) : !torch.vtensor<[2,3,4],f32>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],f32>
+// CHECK: }
+func.func @torch.aten.ones$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32>
+  return %1 : !torch.vtensor<[2,3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.ones$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<1> : tensor<2x3x4xsi64>) : !torch.vtensor<[2,3,4],si64>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],si64>
+// CHECK: }
+func.func @torch.aten.ones$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64>
+  return %1 : !torch.vtensor<[2,3,4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_aten_zeros$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x3x4xf32>) : !torch.vtensor<[2,3,4],f32>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],f32>
+// CHECK: }
+func.func @test_aten_zeros$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32>
+  return %1 : !torch.vtensor<[2,3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_aten_zeros$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0> : tensor<2x3x4xsi64>) : !torch.vtensor<[2,3,4],si64>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],si64>
+// CHECK: }
+func.func @test_aten_zeros$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64>
+  return %1 : !torch.vtensor<[2,3,4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.full$float_fold() -> !torch.vtensor<[2,1,4],f32> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0xFF800000> : tensor<2x1x4xf32>) : !torch.vtensor<[2,1,4],f32>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,1,4],f32>
+// CHECK: }
+func.func @torch.aten.full$float_fold() -> !torch.vtensor<[2,1,4],f32> {
+  %float-Inf = torch.constant.float 0xFFF0000000000000
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int1, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.full %0, %float-Inf, %none, %none, %none, %none : !torch.list<int>, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],f32>
+  return %1 : !torch.vtensor<[2,1,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0> : tensor<2x1x4xsi64>) : !torch.vtensor<[2,1,4],si64>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,1,4],si64>
+// CHECK: }
+func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
+  %int-Inf = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int1, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.full %0, %int-Inf, %none, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],si64>
+  return %1 : !torch.vtensor<[2,1,4],si64>
+}
diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir
index cb39cbd53ece..be2cb7565ca4 100644
--- a/test/Dialect/Torch/fuse-quantized-ops.mlir
+++ b/test/Dialect/Torch/fuse-quantized-ops.mlir
@@ -85,25 +85,24 @@ func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.
 // CHECK-LABEL: func.func @mm_pad_commute
 func.func @mm_pad_commute(%arg0: !torch.vtensor<[8,8],si8>, %arg1: !torch.vtensor<[11,4],si8>) -> !torch.vtensor<[9,4],f32> {
   // CHECK-DAG: %[[cstQuart:.*]] = torch.constant.float 2.500000e-01
-  // CHECK-DAG: %[[int7:.*]] = torch.constant.int 7
-  // CHECK-DAG: %[[none:.*]] = torch.constant.none
+  // CHECK-DAG: %[[padVal:.*]] = torch.vtensor.literal(dense<8.000000e+00> : tensor<f64>) : !torch.vtensor<[],f64>
   // CHECK-DAG: %[[qMax:.*]] = torch.constant.float 1.270000e+02
   // CHECK-DAG: %[[qMin:.*]] = torch.constant.float -1.280000e+02
-  // CHECK-DAG: %[[padVal:.*]] = torch.constant.float 8.000000e+00
   // CHECK-DAG: %[[str:.*]] = torch.constant.str "constant"
   // CHECK-DAG: %[[cstHalf:.*]] = torch.constant.float 5.000000e-01
   // CHECK-DAG: %[[int0:.*]] = torch.constant.int 0
   // CHECK-DAG: %[[int1:.*]] = torch.constant.int 1
   // CHECK-DAG: %[[int2:.*]] = torch.constant.int 2
   // CHECK: %[[PadList:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0]], %[[int1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[EmptyList:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-  // CHECK: %[[Rank0:.*]] = torch.aten.full %[[EmptyList]], %[[padVal]], %[[int7]], %[[none]], %[[none]], %[[none]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
-  // CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[Rank0]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
+  // CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[padVal]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
   // CHECK: %[[Item:.*]] = torch.aten.item %[[Clamp]] : !torch.vtensor<[],f64> -> !torch.float
   // CHECK: %[[NewPad:.*]] = torch.aten.pad %arg0, %[[PadList]], %[[str]], %[[Item]] : !torch.vtensor<[8,8],si8>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],si8>
   // CHECK: %[[NewMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[NewPad]], %[[cstHalf]], %[[int1]] : !torch.vtensor<[9,11],si8>, !torch.float, !torch.int -> !torch.vtensor<[9,11],!torch.qint8>
   // CHECK: %[[OtherMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[cstHalf]], %[[int0]] : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
   // CHECK: %[[MM:.*]] = torch.aten.mm %[[NewMPTQT]], %[[OtherMPTQT]] : !torch.vtensor<[9,11],!torch.qint8>, !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[9,4],!torch.qint32>
+  // CHECK: %[[IR:.*]] = torch.aten.int_repr %[[MM]] : !torch.vtensor<[9,4],!torch.qint32> -> !torch.vtensor<[9,4],si32>
+  // CHECK: %[[QOUT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[IR]], %[[cstQuart]], %[[int0]] : !torch.vtensor<[9,4],si32>, !torch.float, !torch.int -> !torch.vtensor<[9,4],!torch.qint32>
+  // CHECK: %[[OUT:.*]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[9,4],!torch.qint32> -> !torch.vtensor<[9,4],f32>
   %scale = torch.constant.float 0.5
   %false = torch.constant.bool false
   %zero = torch.constant.int 0