From 6ca5fa64d35d8f49b2329b536ea565a1d7311848 Mon Sep 17 00:00:00 2001
From: James Newling
Date: Mon, 4 Aug 2025 09:25:18 -0700
Subject: [PATCH 1/2] ability to use poison as padding value

---
 .../Linalg/Transforms/PadTilingInterface.cpp     | 18 ++++++++++++------
 .../transform-op-pad-tiling-interface.mlir       |  4 +++-
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2e6252336dfeb..3d12bc397813b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
   Value paddingValue;
   if (auto complexTy =
           dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
-    auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
-    paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
-                                               complexTy, complexAttr);
-  } else {
-    paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
-                                             cast<TypedAttr>(paddingValueAttr));
+    if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
+      paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+                                                 complexTy, complexAttr);
+    }
+  } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
+    paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+                                        getElementTypeOrSelf(v.getType()));
+  } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
+    paddingValue =
+        arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
   }
+  assert(paddingValue && "failed to create value from padding attribute");
 
   // Pad the operand to the bounding box defined by `paddedShape`.
   SmallVector<int64_t> tensorShape;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index f7418769f79ca..2857b53103779 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -4,6 +4,7 @@
 // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
 func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
 {
+  // %goo = ub.poison : f32
   %0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
 }
@@ -18,7 +19,8 @@ module attributes {transform.with_named_sequence} {
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
-      padding_values=[0.0 : f32, 0.0 : f32]
+      // padding_values= [poison, 0.0 : f32]
+      padding_values= [0.0 : f32, 0.0 : f32]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     transform.yield

From 46b37d0b300cdd6d86dd38442a471b3570576a14 Mon Sep 17 00:00:00 2001
From: James Newling
Date: Tue, 5 Aug 2025 19:59:22 -0700
Subject: [PATCH 2/2] add test with poison

Signed-off-by: James Newling
---
 .../TransformOps/LinalgTransformOps.cpp          | 24 ++++++++++++++++++------
 .../transform-op-pad-tiling-interface.mlir       | 14 ++++++--------
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bdfc8d020e58f..4c2686aea0794 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -1985,14 +1987,19 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
 
   // Convert the padding values to attributes.
   SmallVector<Attribute> paddingValues;
-  for (auto const &it :
+  for (auto const &[untypedAttr, elementOrTensorType] :
       llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
-    auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+
+    if (isa<ub::PoisonAttr>(untypedAttr)) {
+      paddingValues.push_back(untypedAttr);
+      continue;
+    }
+    auto attr = dyn_cast<TypedAttr>(untypedAttr);
     if (!attr) {
-      emitOpError("expects padding values to be typed attributes");
+      emitOpError("expects padding values to be typed attributes or poison");
       return DiagnosedSilenceableFailure::definiteFailure();
     }
-    Type elementType = getElementTypeOrSelf(std::get<1>(it));
+    Type elementType = getElementTypeOrSelf(elementOrTensorType);
     // Try to parse string attributes to obtain an attribute of element type.
     if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
       auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
@@ -2000,7 +2007,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
           /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
       if (!parsedAttr || parsedAttr.getType() != elementType) {
         auto diag = this->emitOpError("expects a padding that parses to ")
-                    << elementType << ", got " << std::get<0>(it);
+                    << elementType << ", got " << untypedAttr;
         diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
         return DiagnosedSilenceableFailure::definiteFailure();
       }
@@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
        llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
     auto attr = dyn_cast<TypedAttr>(untypedAttr);
     Type elementType = getElementTypeOrSelf(elementOrTensorType);
+
+    if (isa<ub::PoisonAttr>(untypedAttr)) {
+      paddingValues.push_back(untypedAttr);
+      continue;
+    }
     if (!attr) {
-      emitOpError("expects padding values to be typed attributes");
+      emitOpError("expects padding values to be typed attributes or poison");
       return DiagnosedSilenceableFailure::definiteFailure();
     }
     // Try to parse string attributes to obtain an attribute of element type.
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 2857b53103779..9a3dcf0b485d5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -4,7 +4,6 @@
 // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
 func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
 {
-  // %goo = ub.poison : f32
   %0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
 }
@@ -15,12 +14,11 @@ module attributes {transform.with_named_sequence} {
      : (!transform.any_op) -> !transform.any_op
 
    // Tile to 5 then pad to 8
-   %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5] 
+   %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
    %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
-      // padding_values= [poison, 0.0 : f32]
-      padding_values= [0.0 : f32, 0.0 : f32]
+      padding_values= [#ub.poison, 0.0 : f32]
    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
    transform.yield
@@ -35,9 +33,9 @@ func.func @pad_lhs(
   -> tensor<24x25xf32>
 {
  // CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
-  // CHECK: tensor.pad %{{.*}} 
+  // CHECK: tensor.pad %{{.*}}
  // CHECK: : tensor<?x12xf32> to tensor<8x12xf32>
-  // CHECK: tensor.pad %{{.*}} 
+  // CHECK: tensor.pad %{{.*}}
  // CHECK: : tensor<?x25xf32> to tensor<8x25xf32>
  // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x12xf32>, tensor<12x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
  // CHECK: tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1]
@@ -94,7 +92,7 @@ module {
      %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
        padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
      } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-      transform.yield 
+      transform.yield
    }
  }
}
@@ -149,7 +147,7 @@ module {
      %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
        padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
      } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-      transform.yield 
+      transform.yield
    }
  }
}
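
Note on the effect of the new ub::PoisonAttr branch in padOperand (illustrative only, not output copied from the tests above): when a padding value is given as #ub.poison, the pad value is materialized as a ub.poison op of the operand's element type and yielded from the tensor.pad region, instead of an arith.constant. The SSA names and shapes below are assumptions modeled loosely on the pad_fill case.

  %pv = ub.poison : f32
  %padded = tensor.pad %slice low[0, 0] high[%h, 0] {
  ^bb0(%i: index, %j: index):
    tensor.yield %pv : f32
  } : tensor<?x25xf32> to tensor<8x25xf32>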