From db5d7865008555e8f4104099fde70f24147066c3 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak
Date: Tue, 14 Jan 2025 18:46:46 +0000
Subject: [PATCH 1/2] [mlir][IntRangeInference] Infer values for {memref,tensor}.dim

Implement the integer range inference interface for memref.dim and
tensor.dim using shared code.

The inference returns [0, index_max] for dynamic dimensions and takes
the union of the sizes of all dimensions that the `dim` argument could
validly refer to.
---
 mlir/include/mlir/Dialect/MemRef/IR/MemRef.h  |  1 +
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  6 +-
 mlir/include/mlir/Dialect/Tensor/IR/Tensor.h  |  1 +
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  4 +-
 .../Interfaces/Utils/InferIntRangeCommon.h    |  8 +++
 mlir/lib/Dialect/MemRef/IR/CMakeLists.txt     |  2 +
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  7 +++
 mlir/lib/Dialect/Tensor/IR/CMakeLists.txt     |  2 +
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  8 +++
 mlir/lib/Interfaces/Utils/CMakeLists.txt      |  1 +
 .../Interfaces/Utils/InferIntRangeCommon.cpp  | 44 +++++++++++++
 .../Dialect/MemRef/int-range-inference.mlir   | 61 +++++++++++++++++++
 .../Dialect/Tensor/int-range-inference.mlir   | 61 +++++++++++++++++++
 13 files changed, 203 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Dialect/MemRef/int-range-inference.mlir
 create mode 100644 mlir/test/Dialect/Tensor/int-range-inference.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 72463dca715ca..ac383ab46e7a5 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -17,6 +17,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index a0d8d34f38237..c3ee3968abc16 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -577,7 +578,8 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     MemRefsNormalizable,
     ConditionallySpeculatable, NoMemoryEffect,
-    ShapedDimOpInterface]> {
+    ShapedDimOpInterface,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
   let summary = "dimension index operation";
   let description = [{
     The `dim` operation takes a memref and a dimension operand of type `index`.
@@ -1675,7 +1677,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     }]>,

     // Builder that infers the result layout map. The result shape must be
-    // specified. Otherwise, the op may be ambiguous. The output shape for
+    // specified. Otherwise, the op may be ambiguous. The output shape for
     // the op will be inferred using the inferOutputShape() method.
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation)>,
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 0a21c9922b223..bd96337a55407 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 812ac20984502..38874513a4cc0 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -195,7 +196,8 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
 def Tensor_DimOp : Tensor_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     ConditionallySpeculatable, NoMemoryEffect,
-    ShapedDimOpInterface]> {
+    ShapedDimOpInterface,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
   let summary = "dimension index operation";
   let description = [{
     The `tensor.dim` operation takes a tensor and a dimension operand of type
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 3988a8826498a..e46358ccfc46f 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -20,6 +20,8 @@
 #include <optional>

 namespace mlir {
+class ShapedDimOpInterface;
+
 namespace intrange {
 /// Function that performs inference on an array of `ConstantIntRanges`,
 /// abstracted away here to permit writing the function that handles both
@@ -143,6 +145,12 @@
 std::optional<bool> evaluatePred(CmpPredicate pred,
                                  const ConstantIntRanges &lhs,
                                  const ConstantIntRanges &rhs);

+/// Returns the integer range for the result of a `ShapedDimOpInterface` given
+/// the optional inferred ranges for the `dimension` index `maybeDim`. When a
+/// dynamic dimension is encountered, returns [0, signed_max(type(result))].
+ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
+                                            const IntegerValueRange &maybeDim);
+
 } // namespace intrange
 } // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 845914ebd107a..734294bd014c6 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -16,6 +16,8 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRControlFlowInterfaces
   MLIRDialect
   MLIRDialectUtils
+  MLIRInferIntRangeCommon
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9aae46a5c288d..f0aee7a68e0bf 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -915,6 +916,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
   return Speculation::Speculatable;
 }

+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+                                          SetIntLatticeFn setResultRange) {
+  setResultRange(getResult(),
+                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
+}
+
 /// Return a map with key being elements in `vals` and data being number of
 /// occurences of it. Use std::map, since the `vals` here are strides and the
 /// dynamic stride value is the same as the tombstone value for
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index cfdd3847761a4..5425615dac393 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -26,6 +26,8 @@ add_mlir_dialect_library(MLIRTensorDialect
   MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRIR
+  MLIRInferIntRangeCommon
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRParallelCombiningOpInterface
   MLIRShapedOpInterfaces
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 24a1d55315319..e0853cab60fb9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -23,7 +23,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
@@ -782,6 +784,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
   return Speculation::Speculatable;
 }

+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+                                          SetIntLatticeFn setResultRange) {
+  setResultRange(getResult(),
+                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
+}
+
 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
   auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
index ece6c8e46ffea..8c45f66997427 100644
--- a/mlir/lib/Interfaces/Utils/CMakeLists.txt
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRInferIntRangeCommon
   MLIRInferIntRangeInterfaceIncGen

   LINK_LIBS PUBLIC
+  MLIRShapedOpInterfaces
   MLIRInferIntRangeInterface
   MLIRIR
 )
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 1eab4139488bd..2f47939df5a02 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"

 #include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/ShapedOpInterfaces.h"

 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -725,3 +726,46 @@ std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
     return false;
   return std::nullopt;
 }
+
+//===----------------------------------------------------------------------===//
+// Shaped type dimension accessors / ShapedDimOpInterface
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
+                                          const IntegerValueRange &maybeDim) {
+  unsigned width =
+      ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
+  APInt zero = APInt::getZero(width);
+  APInt typeMax = APInt::getSignedMaxValue(width);
+
+  auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
+  if (!shapedTy.hasRank())
+    return ConstantIntRanges::fromSigned(zero, typeMax);
+
+  int64_t rank = shapedTy.getRank();
+  int64_t minDim = 0;
+  int64_t maxDim = rank - 1;
+  if (!maybeDim.isUninitialized()) {
+    const ConstantIntRanges &dim = maybeDim.getValue();
+    minDim = std::max(minDim, dim.smin().getSExtValue());
+    maxDim = std::min(maxDim, dim.smax().getSExtValue());
+  }
+
+  std::optional<ConstantIntRanges> result;
+  auto joinResult = [&](const ConstantIntRanges &thisResult) {
+    if (!result.has_value())
+      result = thisResult;
+    else
+      result = result->rangeUnion(thisResult);
+  };
+  for (int64_t i = minDim; i <= maxDim; ++i) {
+    int64_t length = shapedTy.getDimSize(i);
+
+    if (ShapedType::isDynamic(length))
+      joinResult(ConstantIntRanges::fromSigned(zero, typeMax));
+    else
+      joinResult(ConstantIntRanges::constant(APInt(width, length)));
+  }
+  return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
+}
diff --git a/mlir/test/Dialect/MemRef/int-range-inference.mlir b/mlir/test/Dialect/MemRef/int-range-inference.mlir
new file mode 100644
index 0000000000000..e2aa487eaaa25
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/int-range-inference.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%m: memref<3x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<3x5xi32>
+  return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%m: memref<3x5xi32>, %x: index) -> index {
+  %0 = memref.dim %m, %x : memref<3x5xi32>
+  %1 = 
test.reflect_bounds %0 : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @dim_dynamic +// CHECK: %[[op:.+]] = memref.dim +// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]] +// CHECK: return %[[ret]] +func.func @dim_dynamic(%m: memref) -> index { + %c0 = arith.constant 0 : index + %0 = memref.dim %m, %c0 : memref + %1 = test.reflect_bounds %0 : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @dim_any_dynamic +// CHECK: %[[op:.+]] = memref.dim +// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]] +// CHECK: return %[[ret]] +func.func @dim_any_dynamic(%m: memref, %x: index) -> index { + %0 = memref.dim %m, %x : memref + %1 = test.reflect_bounds %0 : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @dim_some_omitting_dynamic +// CHECK: %[[op:.+]] = memref.dim +// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]] +// CHECK: return %[[ret]] +func.func @dim_some_omitting_dynamic(%m: memref, %x: index) -> index { + %c1 = arith.constant 1 : index + %0 = arith.maxsi %x, %c1 : index + %1 = memref.dim %m, %0 : memref + %2 = test.reflect_bounds %1 : index + return %2 : index +} diff --git a/mlir/test/Dialect/Tensor/int-range-inference.mlir b/mlir/test/Dialect/Tensor/int-range-inference.mlir new file mode 100644 index 0000000000000..384ae781e0e33 --- /dev/null +++ b/mlir/test/Dialect/Tensor/int-range-inference.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s + +// CHECK-LABEL: @dim_const +// CHECK: %[[ret:.+]] = arith.constant 3 : index +// CHECK: return %[[ret]] +func.func @dim_const(%m: tensor<3x5xi32>) -> index { + %c0 = arith.constant 0 : index + %0 = tensor.dim %m, %c0 : tensor<3x5xi32> + return %0 : index +} + +// ----- + +// CHECK-LABEL: @dim_any_static +// CHECK: %[[op:.+]] = tensor.dim +// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]] +// CHECK: return %[[ret]] +func.func @dim_any_static(%m: tensor<3x5xi32>, %x: index) -> index { + %0 = tensor.dim %m, %x : tensor<3x5xi32> + %1 = test.reflect_bounds %0 : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @dim_dynamic +// CHECK: %[[op:.+]] = tensor.dim +// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]] +// CHECK: return %[[ret]] +func.func @dim_dynamic(%m: tensor) -> index { + %c0 = arith.constant 0 : index + %0 = tensor.dim %m, %c0 : tensor + %1 = test.reflect_bounds %0 : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @dim_any_dynamic +// CHECK: %[[op:.+]] = tensor.dim +// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]] +// CHECK: return %[[ret]] +func.func @dim_any_dynamic(%m: tensor, %x: index) -> index { + %0 = tensor.dim %m, %x : tensor + %1 = test.reflect_bounds %0 : index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @dim_some_omitting_dynamic +// CHECK: %[[op:.+]] = tensor.dim +// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]] +// CHECK: return %[[ret]] +func.func 
@dim_some_omitting_dynamic(%m: tensor, %x: index) -> index { + %c1 = arith.constant 1 : index + %0 = arith.maxsi %x, %c1 : index + %1 = tensor.dim %m, %0 : tensor + %2 = test.reflect_bounds %1 : index + return %2 : index +} From fdb82c188c5e41f091dfde0787b83aa8284fa060 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Sat, 18 Jan 2025 03:54:16 +0000 Subject: [PATCH 2/2] Add test for the unranked case --- .../Dialect/MemRef/int-range-inference.mlir | 13 ++++++++ .../Dialect/Tensor/int-range-inference.mlir | 33 +++++++++++++------ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/mlir/test/Dialect/MemRef/int-range-inference.mlir b/mlir/test/Dialect/MemRef/int-range-inference.mlir index e2aa487eaaa25..34568d1d1d520 100644 --- a/mlir/test/Dialect/MemRef/int-range-inference.mlir +++ b/mlir/test/Dialect/MemRef/int-range-inference.mlir @@ -59,3 +59,16 @@ func.func @dim_some_omitting_dynamic(%m: memref, %x: index) -> index %2 = test.reflect_bounds %1 : index return %2 : index } + +// ----- + +// CHECK-LABEL: @dim_unranked +// CHECK: %[[op:.+]] = memref.dim +// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]] +// CHECK: return %[[ret]] +func.func @dim_unranked(%m: memref<*xi32>) -> index { + %c0 = arith.constant 0 : index + %0 = memref.dim %m, %c0 : memref<*xi32> + %1 = test.reflect_bounds %0 : index + return %1 : index +} diff --git a/mlir/test/Dialect/Tensor/int-range-inference.mlir b/mlir/test/Dialect/Tensor/int-range-inference.mlir index 384ae781e0e33..e90ebf5fccb8e 100644 --- a/mlir/test/Dialect/Tensor/int-range-inference.mlir +++ b/mlir/test/Dialect/Tensor/int-range-inference.mlir @@ -3,9 +3,9 @@ // CHECK-LABEL: @dim_const // CHECK: %[[ret:.+]] = arith.constant 3 : index // CHECK: return %[[ret]] -func.func @dim_const(%m: tensor<3x5xi32>) -> index { +func.func @dim_const(%t: tensor<3x5xi32>) -> index { %c0 = arith.constant 0 : index - %0 = tensor.dim %m, %c0 : tensor<3x5xi32> + %0 = tensor.dim %t, %c0 : tensor<3x5xi32> return %0 : index } @@ -15,8 +15,8 @@ func.func @dim_const(%m: tensor<3x5xi32>) -> index { // CHECK: %[[op:.+]] = tensor.dim // CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]] // CHECK: return %[[ret]] -func.func @dim_any_static(%m: tensor<3x5xi32>, %x: index) -> index { - %0 = tensor.dim %m, %x : tensor<3x5xi32> +func.func @dim_any_static(%t: tensor<3x5xi32>, %x: index) -> index { + %0 = tensor.dim %t, %x : tensor<3x5xi32> %1 = test.reflect_bounds %0 : index return %1 : index } @@ -27,9 +27,9 @@ func.func @dim_any_static(%m: tensor<3x5xi32>, %x: index) -> index { // CHECK: %[[op:.+]] = tensor.dim // CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]] // CHECK: return %[[ret]] -func.func @dim_dynamic(%m: tensor) -> index { +func.func @dim_dynamic(%t: tensor) -> index { %c0 = arith.constant 0 : index - %0 = tensor.dim %m, %c0 : tensor + %0 = tensor.dim %t, %c0 : tensor %1 = test.reflect_bounds %0 : index return %1 : index } @@ -40,8 +40,8 @@ func.func @dim_dynamic(%m: tensor) -> index { // CHECK: %[[op:.+]] = tensor.dim // CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]] // CHECK: return %[[ret]] -func.func @dim_any_dynamic(%m: tensor, %x: index) -> index { - %0 = 
tensor.dim %m, %x : tensor +func.func @dim_any_dynamic(%t: tensor, %x: index) -> index { + %0 = tensor.dim %t, %x : tensor %1 = test.reflect_bounds %0 : index return %1 : index } @@ -52,10 +52,23 @@ func.func @dim_any_dynamic(%m: tensor, %x: index) -> index { // CHECK: %[[op:.+]] = tensor.dim // CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]] // CHECK: return %[[ret]] -func.func @dim_some_omitting_dynamic(%m: tensor, %x: index) -> index { +func.func @dim_some_omitting_dynamic(%t: tensor, %x: index) -> index { %c1 = arith.constant 1 : index %0 = arith.maxsi %x, %c1 : index - %1 = tensor.dim %m, %0 : tensor + %1 = tensor.dim %t, %0 : tensor %2 = test.reflect_bounds %1 : index return %2 : index } + +// ----- + +// CHECK-LABEL: @dim_unranked +// CHECK: %[[op:.+]] = tensor.dim +// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]] +// CHECK: return %[[ret]] +func.func @dim_unranked(%t: tensor<*xi32>) -> index { + %c0 = arith.constant 0 : index + %0 = tensor.dim %t, %c0 : tensor<*xi32> + %1 = test.reflect_bounds %0 : index + return %1 : index +}
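
Since the shared helper only goes through `ShapedDimOpInterface`, other dialects with a dim-like op can plug into the same inference. As a minimal sketch, assuming a hypothetical `mydialect::DimOp` whose operand 0 is the shaped value and operand 1 is the `index` operand, and which already implements `ShapedDimOpInterface` and declares the `inferResultRangesFromOptional` method of `InferIntRangeInterface` (mirroring the memref.dim and tensor.dim changes above), the op-side hook reduces to:

// Sketch only: `mydialect` and its DimOp are hypothetical stand-ins for a
// downstream dim-like op; they are assumed to implement ShapedDimOpInterface
// and to declare InferIntRangeInterface's inferResultRangesFromOptional, just
// as memref.dim and tensor.dim do in this patch.
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"

using namespace mlir;

void mydialect::DimOp::inferResultRangesFromOptional(
    ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
  // argRanges[1] is the lattice value inferred for the `index` operand; the
  // shared helper clamps it to [0, rank - 1], unions the sizes of every
  // dimension it could refer to, and treats dynamic sizes and unranked shapes
  // as [0, index_max].
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}

Because the helper falls back to [0, index_max] whenever the shape or the requested dimension is unknown, the unranked and fully dynamic cases exercised by the tests above degrade gracefully rather than producing an empty range.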