diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir index bfaadba093..d7f8a18c5d 100644 --- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir @@ -545,15 +545,12 @@ func.func @dynamic_iota_broadcast_dim0_i64(%arg0 : tensor<2xi64>) -> tensor<5x?x func.return %0 : tensor<5x?xi32> } +// Index-typed shapes are skipped (run shape-legalize-to-stablehlo first). // CHECK-LABEL: @dynamic_iota_broadcast_dim1_index func.func @dynamic_iota_broadcast_dim1_index(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> { - // CHECK-NEXT: [[CAST:%.+]] = arith.index_cast %arg0 : tensor<2xindex> to tensor<2xi64> - // CHECK-NEXT: [[SLICE:%.+]] = stablehlo.slice [[CAST]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> - // CHECK-NEXT: [[IOTA:%.+]] = stablehlo.dynamic_iota [[SLICE]], dim = 0 : (tensor<1xi64>) -> tensor - // CHECK-NEXT: [[BROADCAST:%.+]] = stablehlo.dynamic_broadcast_in_dim [[IOTA]], %arg0, dims = [1] : (tensor, tensor<2xindex>) -> tensor<5x?xi32> + // CHECK-NEXT: stablehlo.dynamic_iota %arg0 %0 = "stablehlo.dynamic_iota"(%arg0) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<5x?xi32> - - // CHECK: return [[BROADCAST]] + // CHECK-NEXT: return func.return %0 : tensor<5x?xi32> } diff --git a/stablehlo/transforms/optimization/Passes.td b/stablehlo/transforms/optimization/Passes.td index 2e6c61b80c..0211da6d47 100644 --- a/stablehlo/transforms/optimization/Passes.td +++ b/stablehlo/transforms/optimization/Passes.td @@ -176,7 +176,6 @@ def StablehloAggressiveSimplificationPass }]; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", - "mlir::arith::ArithDialect", ]; let options = EnableStablehloOptimizationPassFlags<[ "fold-op-element-limit", diff --git a/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp index b676a8eb1e..65f198e64f 100644 --- a/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp +++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp @@ -24,7 +24,6 @@ #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" @@ -506,12 +505,16 @@ struct DynamicIotaOpToBroadcast Value iotaShape = iota.getOutputShape(); auto iotaShapeType = cast(iotaShape.getType()); + if (iotaShapeType.getElementType().isIndex()) + return rewriter.notifyMatchFailure( + iota, "index-typed shapes not supported; run " + "shape-legalize-to-stablehlo first"); + auto iotaShapeI64Type = RankedTensorType::get(iotaShapeType.getShape(), rewriter.getI64Type()); Value iotaShapeI64; - if (iotaShapeType.getElementType().isIndex()) { - iotaShapeI64 = arith::IndexCastOp::create(rewriter, iotaLoc, - iotaShapeI64Type, iotaShape); + if (iotaShapeType.getElementType().isInteger(64)) { + iotaShapeI64 = iotaShape; } else { iotaShapeI64 = stablehlo::ConvertOp::create(rewriter, iotaLoc, iotaShapeI64Type, iotaShape);