Skip to content

Commit 46b37d0

Browse files
committed
add test with poison
Signed-off-by: James Newling <[email protected]>
1 parent 6ca5fa6 commit 46b37d0

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Arith/Utils/Utils.h"
1616
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1717
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
18+
#include "mlir/Dialect/CommonFolders.h"
1819
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1920
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2021
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
2728
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
2829
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
2930
#include "mlir/Dialect/Transform/Utils/Utils.h"
31+
#include "mlir/Dialect/UB/IR/UBOps.h"
3032
#include "mlir/Dialect/Utils/IndexingUtils.h"
3133
#include "mlir/Dialect/Utils/StaticValueUtils.h"
3234
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -1985,22 +1987,27 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
19851987

19861988
// Convert the padding values to attributes.
19871989
SmallVector<Attribute> paddingValues;
1988-
for (auto const &it :
1990+
for (auto const &[untypedAttr, elementOrTensorType] :
19891991
llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1990-
auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1992+
1993+
if (isa<ub::PoisonAttr>(untypedAttr)) {
1994+
paddingValues.push_back(untypedAttr);
1995+
continue;
1996+
}
1997+
auto attr = dyn_cast<TypedAttr>(untypedAttr);
19911998
if (!attr) {
1992-
emitOpError("expects padding values to be typed attributes");
1999+
emitOpError("expects padding values to be typed attributes or poison");
19932000
return DiagnosedSilenceableFailure::definiteFailure();
19942001
}
1995-
Type elementType = getElementTypeOrSelf(std::get<1>(it));
2002+
Type elementType = getElementTypeOrSelf(elementOrTensorType);
19962003
// Try to parse string attributes to obtain an attribute of element type.
19972004
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
19982005
auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
19992006
stringAttr, getContext(), elementType,
20002007
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
20012008
if (!parsedAttr || parsedAttr.getType() != elementType) {
20022009
auto diag = this->emitOpError("expects a padding that parses to ")
2003-
<< elementType << ", got " << std::get<0>(it);
2010+
<< elementType << ", got " << untypedAttr;
20042011
diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
20052012
return DiagnosedSilenceableFailure::definiteFailure();
20062013
}
@@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
22352242
llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
22362243
auto attr = dyn_cast<TypedAttr>(untypedAttr);
22372244
Type elementType = getElementTypeOrSelf(elementOrTensorType);
2245+
2246+
if (isa<ub::PoisonAttr>(untypedAttr)) {
2247+
paddingValues.push_back(untypedAttr);
2248+
continue;
2249+
}
22382250
if (!attr) {
2239-
emitOpError("expects padding values to be typed attributes");
2251+
emitOpError("expects padding values to be typed attributes or poison");
22402252
return DiagnosedSilenceableFailure::definiteFailure();
22412253
}
22422254
// Try to parse string attributes to obtain an attribute of element type.

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
55
func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
66
{
7-
// %goo = ub.poison : f32
87
%0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
98
func.return %0 : tensor<24x25xf32>
109
}
@@ -15,12 +14,11 @@ module attributes {transform.with_named_sequence} {
1514
: (!transform.any_op) -> !transform.any_op
1615

1716
// Tile to 5 then pad to 8
18-
%fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
17+
%fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
1918
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
2019

2120
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
22-
// padding_values= [poison, 0.0 : f32]
23-
padding_values= [0.0 : f32, 0.0 : f32]
21+
padding_values= [#ub.poison, 0.0 : f32]
2422
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
2523

2624
transform.yield
@@ -35,9 +33,9 @@ func.func @pad_lhs(
3533
-> tensor<24x25xf32>
3634
{
3735
// CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
38-
// CHECK: tensor.pad %{{.*}}
36+
// CHECK: tensor.pad %{{.*}}
3937
// CHECK: : tensor<?x12xf32> to tensor<8x12xf32>
40-
// CHECK: tensor.pad %{{.*}}
38+
// CHECK: tensor.pad %{{.*}}
4139
// CHECK: : tensor<?x25xf32> to tensor<8x25xf32>
4240
// CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x12xf32>, tensor<12x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
4341
// CHECK: tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1]
@@ -94,7 +92,7 @@ module {
9492
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
9593
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
9694
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
97-
transform.yield
95+
transform.yield
9896
}
9997
}
10098
}
@@ -149,7 +147,7 @@ module {
149147
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
150148
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
151149
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
152-
transform.yield
150+
transform.yield
153151
}
154152
}
155153
}

0 commit comments

Comments
 (0)