Skip to content

Commit 79383a6

Browse files
committed
address review comments
1 parent af17ff1 commit 79383a6

File tree

4 files changed

+47
-43
lines changed

4 files changed

+47
-43
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,9 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
400400
TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> {
401401
let description = [{
402402
Tiles the operations pointed to by the target handle and fuses their
403-
producers greedily using the options provided as attributes.
403+
producers greedily using the options provided as attributes. Tile sizes
404+
and loop interchange permutation can be provided as either static
405+
attributes or dynamic values (transform parameters or payload handles).
404406

405407
If `apply_cleanup` is true then slice canonicalization is applied between
406408
fusion steps. If `use_forall` is true then tiling method generates a

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -759,14 +759,21 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
759759
}
760760

761761
LogicalResult transform::FuseOp::verify() {
762+
auto iterspace_dim = getStaticTileSizes().size();
762763
ArrayRef<int64_t> permutation = getStaticTileInterchange();
763-
if (!llvm::any_of(permutation,
764-
[](int64_t v) { return ShapedType::isDynamic(v); })) {
765-
auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
766-
if (!std::is_permutation(sequence.begin(), sequence.end(),
767-
permutation.begin(), permutation.end())) {
768-
return emitOpError() << "expects interchange to be a permutation, found "
769-
<< getTileInterchange();
764+
if (permutation.size() > iterspace_dim)
765+
return emitOpError()
766+
<< "interchange length exceeds iteration space dimensions ("
767+
<< iterspace_dim << "), found " << getTileInterchange();
768+
llvm::SmallDenseSet<int64_t, 4> seen;
769+
for (int64_t v : permutation) {
770+
if (!ShapedType::isDynamic(v)) {
771+
if (v < 0 || v >= iterspace_dim)
772+
return emitOpError() << "expects interchange values to be in range [0, "
773+
<< iterspace_dim << "), found: " << v;
774+
auto result = seen.insert(v);
775+
if (!result.second)
776+
return emitOpError() << "found duplicate interchange value: " << v;
770777
}
771778
}
772779

@@ -780,37 +787,12 @@ LogicalResult transform::FuseOp::verify() {
780787
}
781788

782789
SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
783-
ValueRange dynamicValues = getTileSizes();
784-
ArrayRef<int64_t> staticValues = getStaticTileSizes();
785-
SmallVector<OpFoldResult> results;
786-
results.reserve(staticValues.size());
787-
unsigned dynamicPos = 0;
788-
Builder builder(getContext());
789-
for (int64_t size : staticValues) {
790-
if (size == ShapedType::kDynamic) {
791-
results.push_back(dynamicValues[dynamicPos++]);
792-
} else {
793-
results.push_back(builder.getIndexAttr(size));
794-
}
795-
}
796-
return results;
790+
return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext());
797791
}
798792

799793
SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
800-
ValueRange dynamicValues = getTileInterchange();
801-
ArrayRef<int64_t> staticValues = getStaticTileInterchange();
802-
SmallVector<OpFoldResult> results;
803-
results.reserve(staticValues.size());
804-
unsigned dynamicPos = 0;
805-
Builder builder(getContext());
806-
for (int64_t size : staticValues) {
807-
if (size == ShapedType::kDynamic) {
808-
results.push_back(dynamicValues[dynamicPos++]);
809-
} else {
810-
results.push_back(builder.getIndexAttr(size));
811-
}
812-
}
813-
return results;
794+
return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
795+
getContext());
814796
}
815797

816798
void transform::FuseOp::getEffects(

mlir/python/mlir/dialects/transform/structured.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def __init__(
146146
*,
147147
tile_sizes: Optional[MixedValues] = None,
148148
tile_interchange: Optional[MixedValues] = None,
149-
apply_cleanup: Optional[bool] = False,
150-
use_forall: Optional[bool] = False,
149+
apply_cleanup: bool = False,
150+
use_forall: bool = False,
151151
loc=None,
152152
ip=None,
153153
):
@@ -160,8 +160,8 @@ def __init__(
160160
*,
161161
tile_sizes: Optional[MixedValues] = None,
162162
tile_interchange: Optional[MixedValues] = None,
163-
apply_cleanup: Optional[bool] = False,
164-
use_forall: Optional[bool] = False,
163+
apply_cleanup: bool = False,
164+
use_forall: bool = False,
165165
loc=None,
166166
ip=None,
167167
):
@@ -174,8 +174,8 @@ def __init__(
174174
*,
175175
tile_sizes: Optional[MixedValues] = None,
176176
tile_interchange: Optional[MixedValues] = None,
177-
apply_cleanup: Optional[bool] = False,
178-
use_forall: Optional[bool] = False,
177+
apply_cleanup: bool = False,
178+
use_forall: bool = False,
179179
loc=None,
180180
ip=None,
181181
):
@@ -192,7 +192,7 @@ def __init__(
192192
_,
193193
) = _dispatch_dynamic_index_list(tile_interchange)
194194
num_loops = (
195-
1 if use_forall else sum(0 if v == 0 else 1 for v in static_tile_sizes)
195+
1 if use_forall else sum(1 for v in static_tile_sizes if v != 0)
196196
)
197197

198198
if isinstance(loop_types_or_target, (Operation, Value, OpView)):

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,26 @@ def testFuseOpParams(target):
158158
# CHECK-SAME: (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
159159

160160

161+
@run
162+
@create_sequence
163+
def testFuseOpHandles(target):
164+
size1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
165+
ichange1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
166+
structured.FuseOp(
167+
target,
168+
tile_sizes=[size1, 8],
169+
tile_interchange=[ichange1, 1],
170+
)
171+
# CHECK-LABEL: TEST: testFuseOpHandles
172+
# CHECK: transform.sequence
173+
# CHECK: %[[H:.*]] = transform.structured.match
174+
# CHECK: %[[I:.*]] = transform.structured.match
175+
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse
176+
# CHECK-SAME: tile_sizes [%[[H]], 8]
177+
# CHECK-SAME: interchange [%[[I]], 1]
178+
# CHECK-SAME: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
179+
180+
161181
@run
162182
@create_sequence
163183
def testFuseOpAttributes(target):

0 commit comments

Comments
 (0)