From a122eeec7560cd4a0c0a830cbd28d99a7090756e Mon Sep 17 00:00:00 2001 From: Jianjian GUAN Date: Thu, 2 Jan 2025 15:18:34 +0800 Subject: [PATCH 1/5] [mlir][arith] Support bitcast with index type Use kInternalStorageBitWidth as the bit width of index type. --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 13 ++++++++++++- mlir/test/Dialect/Arith/ops.mlir | 6 ++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index e016a6e16e59f..e16eefd32212f 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1747,7 +1747,18 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (!srcType || !dstType) return false; - return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); + unsigned srcWidth, dstWidth; + if (auto indexTy = dyn_cast(srcType)) + srcWidth = IndexType::kInternalStorageBitWidth; + else + srcWidth = srcType.getIntOrFloatBitWidth(); + + if (auto indexTy = dyn_cast(dstType)) + dstWidth = IndexType::kInternalStorageBitWidth; + else + dstWidth = dstType.getIntOrFloatBitWidth(); + + return srcWidth == dstWidth; } OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index f684e02344a51..46cb1993a3b78 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -954,6 +954,12 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x return %0 : vector<[8]xi32> } +// CHECK-LABEL: test_bitcast_index +func.func @test_bitcast_index(%arg0 : i64) -> index { + %0 = arith.bitcast %arg0 : i64 to index + return %0 : index +} + // CHECK-LABEL: test_cmpi func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 { %0 = arith.cmpi ne, %arg0, %arg1 : i64 From 4243d0377d20e94183204de23ba50fe0bf0b591f Mon Sep 17 00:00:00 2001 From: Jianjian GUAN Date: Sat, 4 Jan 2025 15:22:01 +0800 Subject: [PATCH 2/5] disable index type for bitcast --- .../include/mlir/Dialect/Arith/IR/ArithOps.td | 5 ++--- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 19 +++---------------- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 ++++ mlir/test/Dialect/Arith/invalid.mlir | 16 ++++++++++++++++ mlir/test/Dialect/Arith/ops.mlir | 6 ------ mlir/test/Dialect/Tensor/invalid.mlir | 16 ++++++++++++++++ 6 files changed, 41 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 0722ff68d890d..80b90f2ae480d 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1392,11 +1392,10 @@ def Arith_IndexCastUIOp // BitcastOp //===----------------------------------------------------------------------===// -// Bitcast can convert between memrefs of signless integers, indices, and -// floats too. +// Bitcast can convert between memrefs of signless integers and floats. def BitcastTypeConstraint : TypeConstraint.predicate]>, + MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>, "signless-integer-or-float-like or memref of signless-integer or float">; def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint, diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index e16eefd32212f..7ca104691e6df 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1740,25 +1740,12 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (!areValidCastInputsAndOutputs(inputs, outputs)) return false; - auto srcType = - getTypeIfLikeOrMemRef(inputs.front()); - auto dstType = - getTypeIfLikeOrMemRef(outputs.front()); + auto srcType = getTypeIfLikeOrMemRef(inputs.front()); + auto dstType = getTypeIfLikeOrMemRef(outputs.front()); if (!srcType || !dstType) return false; - unsigned srcWidth, dstWidth; - if (auto indexTy = dyn_cast(srcType)) - srcWidth = IndexType::kInternalStorageBitWidth; - else - srcWidth = srcType.getIntOrFloatBitWidth(); - - if (auto indexTy = dyn_cast(dstType)) - dstWidth = IndexType::kInternalStorageBitWidth; - else - dstWidth = dstType.getIntOrFloatBitWidth(); - - return srcWidth == dstWidth; + return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); } OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 24a1d55315319..9cebb5534ebdd 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -219,6 +219,10 @@ bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (!aT || !bT) return false; + if (isa(aT.getElementType()) || + isa(bT.getElementType())) + return false; + if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth()) return false; diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir index 088da475e8eb4..54c82f3802ced 100644 --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -853,3 +853,19 @@ func.func @select_tensor_encoding( %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo"> return %0 : tensor<8xi32, "foo"> } + +// ----- + +func.func @bitcast_index_0(%arg0 : i64) -> index { + // expected-error @+1 {{'arith.bitcast' op operand type 'i64' and result type 'index' are cast incompatible}} + %0 = arith.bitcast %arg0 : i64 to index + return %0 : index +} + +// ----- + +func.func @bitcast_index_1(%arg0 : index) -> i64 { + // expected-error @+1 {{'arith.bitcast' op operand type 'index' and result type 'i64' are cast incompatible}} + %0 = arith.bitcast %arg0 : index to i64 + return %0 : i64 +} diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index 46cb1993a3b78..f684e02344a51 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -954,12 +954,6 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x return %0 : vector<[8]xi32> } -// CHECK-LABEL: test_bitcast_index -func.func @test_bitcast_index(%arg0 : i64) -> index { - %0 = arith.bitcast %arg0 : i64 to index - return %0 : index -} - // CHECK-LABEL: test_cmpi func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 { %0 = arith.cmpi ne, %arg0, %arg1 : i64 diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 1de3e281bc462..23c1f5360d361 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -807,3 +807,19 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape( %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor -> tensor return %0 : tensor } + +// ----- + +func.func @bitcast_index_0(%arg0 : tensor) -> tensor { + // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor' and result type 'tensor' are cast incompatible}} + %0 = tensor.bitcast %arg0 : tensor to tensor + return %0 : tensor +} + +// ----- + +func.func @bitcast_index_1(%arg0 : tensor) -> tensor { + // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor' and result type 'tensor' are cast incompatible}} + %0 = tensor.bitcast %arg0 : tensor to tensor + return %0 : tensor +} From 31a53225b96dac17363d6156a1f33f115fa1672e Mon Sep 17 00:00:00 2001 From: Jianjian GUAN Date: Tue, 7 Jan 2025 14:18:08 +0800 Subject: [PATCH 3/5] change type constraint of bitcast --- mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 2 +- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 6 ++++-- mlir/include/mlir/IR/CommonTypeConstraints.td | 5 +++++ mlir/test/Dialect/Arith/invalid.mlir | 4 ++-- mlir/test/Dialect/Tensor/invalid.mlir | 4 ++-- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 80b90f2ae480d..10d7519e09dbe 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1394,7 +1394,7 @@ def Arith_IndexCastUIOp // Bitcast can convert between memrefs of signless integers and floats. def BitcastTypeConstraint : TypeConstraint.predicate]>, "signless-integer-or-float-like or memref of signless-integer or float">; diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 812ac20984502..8ad1b23cb2bfe 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -75,8 +75,10 @@ def Tensor_BitcastOp : Tensor_Op<"bitcast", [ ``` }]; - let arguments = (ins AnyTensor:$source); - let results = (outs AnyTensor:$dest); + let arguments = (ins TensorOf<[AnySignlessInteger, AnyUnsignedInteger, + AnySignedInteger, AnyFloat]>:$source); + let results = (outs TensorOf<[AnySignlessInteger, AnyUnsignedInteger, + AnySignedInteger, AnyFloat]>:$dest); let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index e592910303568..38bff642630fe 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -908,6 +908,11 @@ def BoolLike : TypeOrContainer; def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank; +// Type constraint for signless-integer-like types: signless integers, +// vectors of signless integers or tensors of signless integers. +def SignlessInteger : TypeOrValueSemanticsContainer< + AnySignlessInteger, "signless-integer">; + // Type constraint for signless-integer-like types: signless integers, indices, // vectors of signless integers or indices, tensors of signless integers. def SignlessIntegerLike : TypeOrValueSemanticsContainer< diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir index 54c82f3802ced..7bd68372de471 100644 --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -857,7 +857,7 @@ func.func @select_tensor_encoding( // ----- func.func @bitcast_index_0(%arg0 : i64) -> index { - // expected-error @+1 {{'arith.bitcast' op operand type 'i64' and result type 'index' are cast incompatible}} + // expected-error @+1 {{'arith.bitcast' op result #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}} %0 = arith.bitcast %arg0 : i64 to index return %0 : index } @@ -865,7 +865,7 @@ func.func @bitcast_index_0(%arg0 : i64) -> index { // ----- func.func @bitcast_index_1(%arg0 : index) -> i64 { - // expected-error @+1 {{'arith.bitcast' op operand type 'index' and result type 'i64' are cast incompatible}} + // expected-error @+1 {{'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}} %0 = arith.bitcast %arg0 : index to i64 return %0 : i64 } diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 23c1f5360d361..0c6d8f4e05c33 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -811,7 +811,7 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape( // ----- func.func @bitcast_index_0(%arg0 : tensor) -> tensor { - // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor' and result type 'tensor' are cast incompatible}} + // expected-error @+1 {{'tensor.bitcast' op result #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor'}} %0 = tensor.bitcast %arg0 : tensor to tensor return %0 : tensor } @@ -819,7 +819,7 @@ func.func @bitcast_index_0(%arg0 : tensor) -> tensor { // ----- func.func @bitcast_index_1(%arg0 : tensor) -> tensor { - // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor' and result type 'tensor' are cast incompatible}} + // expected-error @+1 {{'tensor.bitcast' op operand #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor'}} %0 = tensor.bitcast %arg0 : tensor to tensor return %0 : tensor } From 20981e4668cd09f12476706e7b644d29245aecea Mon Sep 17 00:00:00 2001 From: Jianjian GUAN Date: Tue, 7 Jan 2025 17:49:45 +0800 Subject: [PATCH 4/5] Address rename comment --- .../include/mlir/Dialect/Arith/IR/ArithOps.td | 30 +++++++++---------- mlir/include/mlir/Dialect/Math/IR/MathOps.td | 10 +++---- mlir/include/mlir/IR/CommonTypeConstraints.td | 13 +++++--- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 --- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 10d7519e09dbe..ea9b0f6509b80 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -51,8 +51,8 @@ class Arith_BinaryOp traits = []> : class Arith_IntBinaryOp traits = []> : Arith_BinaryOp]>, - Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>, - Results<(outs SignlessIntegerLike:$result)>; + Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>, + Results<(outs SignlessIntegerOrIndexLike:$result)>; // Base class for integer binary operations without undefined behavior. class Arith_TotalIntBinaryOp traits = []> : @@ -155,11 +155,11 @@ class Arith_IntBinaryOpWithOverflowFlags traits = [ Arith_BinaryOp, DeclareOpInterfaceMethods]>, - Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs, + Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs, DefaultValuedAttr< Arith_IntegerOverflowAttr, "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>, - Results<(outs SignlessIntegerLike:$result)> { + Results<(outs SignlessIntegerOrIndexLike:$result)> { let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)? attr-dict `:` type($result) }]; @@ -198,7 +198,7 @@ def Arith_ConstantOp : Op { // Index cast can convert between memrefs of signless integers and indices too. def IndexCastTypeConstraint : TypeConstraint.predicate]>, "signless-integer-like or memref of signless-integer">; @@ -1394,7 +1394,7 @@ def Arith_IndexCastUIOp // Bitcast can convert between memrefs of signless integers and floats. def BitcastTypeConstraint : TypeConstraint.predicate]>, "signless-integer-or-float-like or memref of signless-integer or float">; @@ -1495,8 +1495,8 @@ def Arith_CmpIOp }]; let arguments = (ins Arith_CmpIPredicateAttr:$predicate, - SignlessIntegerLikeOfAnyRank:$lhs, - SignlessIntegerLikeOfAnyRank:$rhs); + SignlessIntegerOrIndexLikeOfAnyRank:$lhs, + SignlessIntegerOrIndexLikeOfAnyRank:$rhs); let hasFolder = 1; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index 3f6d2d2e44783..5990a9f0d2e44 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -28,8 +28,8 @@ class Math_Op traits = []> : // tensor thereof. class Math_IntegerUnaryOp traits = []> : Math_Op { - let arguments = (ins SignlessIntegerLike:$operand); - let results = (outs SignlessIntegerLike:$result); + let arguments = (ins SignlessIntegerOrIndexLike:$operand); + let results = (outs SignlessIntegerOrIndexLike:$result); let assemblyFormat = "$operand attr-dict `:` type($result)"; } @@ -55,8 +55,8 @@ class Math_FloatUnaryOp traits = []> : // type, vector or tensor thereof. class Math_IntegerBinaryOp traits = []> : Math_Op { - let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs); - let results = (outs SignlessIntegerLike:$result); + let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs); + let results = (outs SignlessIntegerOrIndexLike:$result); let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; } @@ -976,7 +976,7 @@ def Math_FPowIOp : Math_Op<"fpowi", ``` }]; - let arguments = (ins FloatLike:$lhs, SignlessIntegerLike:$rhs, + let arguments = (ins FloatLike:$lhs, SignlessIntegerOrIndexLike:$rhs, DefaultValuedAttr:$fastmath); let results = (outs FloatLike:$result); diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 38bff642630fe..82e335e30b6fa 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -910,24 +910,29 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank; // Type constraint for signless-integer-like types: signless integers, // vectors of signless integers or tensors of signless integers. -def SignlessInteger : TypeOrValueSemanticsContainer< +def SignlessIntegerLike : TypeOrValueSemanticsContainer< AnySignlessInteger, "signless-integer">; // Type constraint for signless-integer-like types: signless integers, indices, // vectors of signless integers or indices, tensors of signless integers. -def SignlessIntegerLike : TypeOrValueSemanticsContainer< +def SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer< AnySignlessIntegerOrIndex, "signless-integer-like">; -def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank< +def SignlessIntegerOrIndexLikeOfAnyRank : TypeOrContainerOfAnyRank< AnySignlessIntegerOrIndex, "signless-integer-like">; // Type constraint for float-like types: floats, vectors or tensors thereof. def FloatLike : TypeOrContainer; -// Type constraint for signless-integer-like or float-like types. +// Type constraint for signless-integer-or-index-like or float-like types. def SignlessIntegerOrFloatLike : TypeConstraint, "signless-integer-like or floating-point-like">; +// Type constraint for signless-integer-or-index-like or float-like types. +def SignlessIntegerOrIndexOrFloatLike : TypeConstraint, + "signless-integer-or-index-like or floating-point-like">; + #endif // COMMON_TYPE_CONSTRAINTS_TD diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 9cebb5534ebdd..24a1d55315319 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -219,10 +219,6 @@ bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (!aT || !bT) return false; - if (isa(aT.getElementType()) || - isa(bT.getElementType())) - return false; - if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth()) return false; From 5767f51d164056c2cd675e3d865d21b1e3e3029f Mon Sep 17 00:00:00 2001 From: Jianjian GUAN Date: Fri, 24 Jan 2025 16:10:42 +0800 Subject: [PATCH 5/5] Fix flang error --- flang/include/flang/Optimizer/Dialect/FIRTypes.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td index 6ae74f16a72d3..41e765c1cb7b9 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -579,7 +579,7 @@ def IsBaseBoxTypePred def fir_BaseBoxType : Type; // Generalized FIR and standard dialect types representing intrinsic types -def AnyIntegerLike : TypeConstraint, "any integer">; def AnyLogicalLike : TypeConstraint