8 changes: 4 additions & 4 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4888,11 +4888,11 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
auto elementType = shapedty.getElementType();
if (isa<IntegerType>(elementType)) {
if (isa<mlir::IntegerType>(elementType)) {
Attribute attribute = IntegerAttr::get(elementType, 1);
return DenseElementsAttr::get(shapedty, attribute);
}
if (isa<FloatType>(elementType)) {
if (isa<mlir::FloatType>(elementType)) {
Attribute attribute = FloatAttr::get(elementType, 1.0);
return DenseElementsAttr::get(shapedty, attribute);
}
@@ -4932,7 +4932,7 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
Attribute attribute = IntegerAttr::get(elementType, 0);
return DenseElementsAttr::get(shapedty, attribute);
}
if (isa<FloatType>(elementType)) {
if (isa<mlir::FloatType>(elementType)) {
Attribute attribute = FloatAttr::get(elementType, 0.0);
return DenseElementsAttr::get(shapedty, attribute);
}
@@ -4972,7 +4972,7 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
return DenseElementsAttr::get(shapedty, attribute);
}
}
if (isa<FloatType>(elementType)) {
if (isa<mlir::FloatType>(elementType)) {
double value = 0.0;
if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
Attribute attribute = FloatAttr::get(elementType, value);
14 changes: 6 additions & 8 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -857,10 +857,9 @@ func.func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vt
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32>
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xi32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
%int4 = torch.constant.int 4
@@ -925,10 +924,9 @@ func.func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !to
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<3x4xi32>}> : () -> tensor<3x4xi32>
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xi32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
%int4 = torch.constant.int 4
98 changes: 98 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -3410,3 +3410,101 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1
torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
return %3 : !torch.vtensor<[?],f32>
}

// -----

// CHECK-LABEL:   func.func @torch.aten.ones$float_fold() -> !torch.vtensor<[2,3,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x3x4xf32>) : !torch.vtensor<[2,3,4],f32>
// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],f32>
// CHECK: }
func.func @torch.aten.ones$float_fold() -> !torch.vtensor<[2,3,4],f32> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32>
return %1 : !torch.vtensor<[2,3,4],f32>
}

// -----

// CHECK-LABEL:   func.func @torch.aten.ones$int_fold() -> !torch.vtensor<[2,3,4],si64> {
// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<1> : tensor<2x3x4xsi64>) : !torch.vtensor<[2,3,4],si64>
// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],si64>
// CHECK: }
func.func @torch.aten.ones$int_fold() -> !torch.vtensor<[2,3,4],si64> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64>
return %1 : !torch.vtensor<[2,3,4],si64>
}

// -----

// CHECK-LABEL: func.func @test_aten_zeros$float_fold() -> !torch.vtensor<[2,3,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x3x4xf32>) : !torch.vtensor<[2,3,4],f32>
// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],f32>
// CHECK: }
func.func @test_aten_zeros$float_fold() -> !torch.vtensor<[2,3,4],f32> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32>
return %1 : !torch.vtensor<[2,3,4],f32>
}

// -----

// CHECK-LABEL: func.func @test_aten_zeros$int_fold() -> !torch.vtensor<[2,3,4],si64> {
// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0> : tensor<2x3x4xsi64>) : !torch.vtensor<[2,3,4],si64>
// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],si64>
// CHECK: }
func.func @test_aten_zeros$int_fold() -> !torch.vtensor<[2,3,4],si64> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64>
return %1 : !torch.vtensor<[2,3,4],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.full$float_fold() -> !torch.vtensor<[2,1,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0xFF800000> : tensor<2x1x4xf32>) : !torch.vtensor<[2,1,4],f32>
// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,1,4],f32>
// CHECK: }
func.func @torch.aten.full$float_fold() -> !torch.vtensor<[2,1,4],f32> {
%float-Inf = torch.constant.float 0xFFF0000000000000
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int2, %int1, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.full %0, %float-Inf, %none, %none, %none, %none : !torch.list<int>, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],f32>
return %1 : !torch.vtensor<[2,1,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0> : tensor<2x1x4xsi64>) : !torch.vtensor<[2,1,4],si64>
// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,1,4],si64>
// CHECK: }
func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int2, %int1, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.full %0, %int0, %none, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],si64>
return %1 : !torch.vtensor<[2,1,4],si64>
}
11 changes: 5 additions & 6 deletions test/Dialect/Torch/fuse-quantized-ops.mlir
@@ -85,25 +85,24 @@ func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.
// CHECK-LABEL: func.func @mm_pad_commute
func.func @mm_pad_commute(%arg0: !torch.vtensor<[8,8],si8>, %arg1: !torch.vtensor<[11,4],si8>) -> !torch.vtensor<[9,4],f32> {
// CHECK-DAG: %[[cstQuart:.*]] = torch.constant.float 2.500000e-01
// CHECK-DAG: %[[int7:.*]] = torch.constant.int 7
// CHECK-DAG: %[[none:.*]] = torch.constant.none
// CHECK-DAG: %[[padVal:.*]] = torch.vtensor.literal(dense<8.000000e+00> : tensor<f64>) : !torch.vtensor<[],f64>
// CHECK-DAG: %[[qMax:.*]] = torch.constant.float 1.270000e+02
// CHECK-DAG: %[[qMin:.*]] = torch.constant.float -1.280000e+02
// CHECK-DAG: %[[padVal:.*]] = torch.constant.float 8.000000e+00
// CHECK-DAG: %[[str:.*]] = torch.constant.str "constant"
// CHECK-DAG: %[[cstHalf:.*]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[int0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[int1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[int2:.*]] = torch.constant.int 2
// CHECK: %[[PadList:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0]], %[[int1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[EmptyList:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[Rank0:.*]] = torch.aten.full %[[EmptyList]], %[[padVal]], %[[int7]], %[[none]], %[[none]], %[[none]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[Rank0]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[padVal]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[Item:.*]] = torch.aten.item %[[Clamp]] : !torch.vtensor<[],f64> -> !torch.float
// CHECK: %[[NewPad:.*]] = torch.aten.pad %arg0, %[[PadList]], %[[str]], %[[Item]] : !torch.vtensor<[8,8],si8>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],si8>
// CHECK: %[[NewMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[NewPad]], %[[cstHalf]], %[[int1]] : !torch.vtensor<[9,11],si8>, !torch.float, !torch.int -> !torch.vtensor<[9,11],!torch.qint8>
// CHECK: %[[OtherMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[cstHalf]], %[[int0]] : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
// CHECK: %[[MM:.*]] = torch.aten.mm %[[NewMPTQT]], %[[OtherMPTQT]] : !torch.vtensor<[9,11],!torch.qint8>, !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[9,4],!torch.qint32>
// CHECK: %[[IR:.*]] = torch.aten.int_repr %[[MM]] : !torch.vtensor<[9,4],!torch.qint32> -> !torch.vtensor<[9,4],si32>
// CHECK: %[[QOUT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[IR]], %[[cstQuart]], %[[int0]] : !torch.vtensor<[9,4],si32>, !torch.float, !torch.int -> !torch.vtensor<[9,4],!torch.qint32>
// CHECK: %[[OUT:.*]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[9,4],!torch.qint32> -> !torch.vtensor<[9,4],f32>
%scale = torch.constant.float 0.5
%false = torch.constant.bool false
%zero = torch.constant.int 0