From 401d54dc215aee4ff7bc4b7aad2ce49d7b6d2584 Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 14 Apr 2025 13:38:40 -0700 Subject: [PATCH] fold broadcast(splat) -> splat Signed-off-by: James Newling --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 ++ mlir/test/Dialect/Vector/canonicalize.mlir | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bee5c1fd6ed58..4b7ff757c150a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2590,6 +2590,8 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { } if (auto attr = llvm::dyn_cast(adaptor.getSource())) return DenseElementsAttr::get(vectorType, attr.getSplatValue()); + if (llvm::dyn_cast(adaptor.getSource())) + return ub::PoisonAttr::get(getContext()); return {}; } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 78b0ea78849e8..420244c5e734a 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1151,6 +1151,28 @@ func.func @bitcast_i8_to_i32() -> (vector<4xi32>, vector<4xi32>) { // ----- +// CHECK-LABEL: broadcast_poison +// CHECK: %[[POISON:.*]] = ub.poison : vector<4x6xi8> +// CHECK: return %[[POISON]] : vector<4x6xi8> +func.func @broadcast_poison() -> vector<4x6xi8> { + %poison = ub.poison : vector<6xi8> + %broadcast = vector.broadcast %poison : vector<6xi8> to vector<4x6xi8> + return %broadcast : vector<4x6xi8> +} + +// ----- + +// CHECK-LABEL: broadcast_splat_constant +// CHECK: %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8> +// CHECK: return %[[CONST]] : vector<4x6xi8> +func.func @broadcast_splat_constant() -> vector<4x6xi8> { + %cst = arith.constant dense<1> : vector<6xi8> + %broadcast = vector.broadcast %cst : vector<6xi8> to vector<4x6xi8> + return %broadcast : vector<4x6xi8> +} + +// ----- + // CHECK-LABEL: broadcast_folding1 // CHECK: %[[CST:.*]] = arith.constant dense<42> : vector<4xi32> // CHECK-NOT: vector.broadcast