Skip to content

Commit 3fa3ee1

Browse files
committed
FuseOp tile sizes and interchange args accept dynamic values
1 parent 58f56a2 commit 3fa3ee1

File tree

7 files changed

+295
-67
lines changed

7 files changed

+295
-67
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -395,33 +395,72 @@ def EliminateLinalgOpAnchoredEmptyTensorsOp
395395
//===----------------------------------------------------------------------===//
396396

397397
def FuseOp : Op<Transform_Dialect, "structured.fuse",
398-
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
399-
DeclareOpInterfaceMethods<TransformOpInterface>,
400-
ReportTrackingListenerFailuresOpTrait]> {
398+
[AttrSizedOperandSegments,
399+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
400+
TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> {
401401
let description = [{
402402
Tiles the operations pointed to by the target handle and fuses their
403403
producers greedily using the options provided as attributes.
404404

405405
If `apply_cleanup` is true then slice canonicalization is applied between
406-
fusion steps.
406+
fusion steps. If `use_forall` is true then tiling method generates a
407+
`scf.forall` loop instead of `scf.for` loops.
407408
}];
408409

409410
let arguments =
410411
(ins TransformHandleTypeInterface:$target,
411-
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
412-
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
413-
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup,
414-
DefaultValuedAttr<BoolAttr, "false">:$use_forall);
412+
Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_sizes,
413+
Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_interchange,
414+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
415+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_interchange,
416+
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup,
417+
DefaultValuedAttr<BoolAttr, "false">:$use_forall);
415418
let results = (outs TransformHandleTypeInterface:$transformed,
416419
Variadic<TransformHandleTypeInterface>:$loops);
420+
let builders = [
421+
OpBuilder<(ins "TypeRange":$loopTypes,
422+
"Value":$target,
423+
"ArrayRef<int64_t>":$staticTileSizes,
424+
"ArrayRef<int64_t>":$staticTileInterchange,
425+
CArg<"bool", "false">:$applyCleanup,
426+
CArg<"bool", "false">:$useForall)>,
427+
OpBuilder<(ins "TypeRange":$loopTypes,
428+
"Value":$target,
429+
"ArrayRef<OpFoldResult>":$mixedTileSizes,
430+
"ArrayRef<OpFoldResult>":$mixedTileInterchange,
431+
CArg<"bool", "false">:$applyCleanup,
432+
CArg<"bool", "false">:$useForall)>,
433+
OpBuilder<(ins "Value":$target,
434+
"ArrayRef<int64_t>":$staticTileSizes,
435+
"ArrayRef<int64_t>":$staticTileInterchange,
436+
CArg<"bool", "false">:$applyCleanup,
437+
CArg<"bool", "false">:$useForall)>,
438+
OpBuilder<(ins "Value":$target,
439+
"ArrayRef<OpFoldResult>":$mixedTileSizes,
440+
"ArrayRef<OpFoldResult>":$mixedTileInterchange,
441+
CArg<"bool", "false">:$applyCleanup,
442+
CArg<"bool", "false">:$useForall)>,
443+
];
417444

418445
let assemblyFormat = [{
419-
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
446+
$target
447+
(`tile_sizes` custom<DynamicIndexList>($tile_sizes, $static_tile_sizes)^)?
448+
(`interchange` custom<DynamicIndexList>($tile_interchange, $static_tile_interchange)^)?
420449
(`apply_cleanup` `=` $apply_cleanup^)?
421450
(`use_forall` `=` $use_forall^)? attr-dict
422451
`:` functional-type(operands, results)
423452
}];
424453
let hasVerifier = 1;
454+
455+
let extraClassDeclaration = [{
456+
::mlir::DiagnosedSilenceableFailure apply(
457+
::mlir::transform::TransformRewriter &rewriter,
458+
::mlir::transform::TransformResults &transformResults,
459+
::mlir::transform::TransformState &state);
460+
461+
::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileSizes();
462+
::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileInterchange();
463+
}];
425464
}
426465

427466
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 145 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
576576
// FuseOp
577577
//===----------------------------------------------------------------------===//
578578

579+
void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
580+
TypeRange loopTypes, Value target,
581+
ArrayRef<int64_t> staticTileSizes,
582+
ArrayRef<int64_t> staticTileInterchange,
583+
bool applyCleanup, bool useForall) {
584+
return build(
585+
builder, result, loopTypes,
586+
/*target=*/target,
587+
/*mixedTileSizes=*/
588+
getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
589+
/*mixedTileInterchange=*/
590+
getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
591+
applyCleanup, useForall);
592+
}
593+
594+
void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
595+
Value target, ArrayRef<int64_t> staticTileSizes,
596+
ArrayRef<int64_t> staticTileInterchange,
597+
bool applyCleanup, bool useForall) {
598+
return build(
599+
builder, result,
600+
/*target=*/target,
601+
/*mixedTileSizes=*/
602+
getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
603+
/*mixedTileInterchange=*/
604+
getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
605+
applyCleanup, useForall);
606+
}
607+
608+
void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
609+
Value target,
610+
ArrayRef<OpFoldResult> mixedTileSizes,
611+
ArrayRef<OpFoldResult> mixedTileInterchange,
612+
bool applyCleanup, bool useForall) {
613+
// Loop types are automaticaly splat by the callee, setting up one is
614+
// enough.
615+
SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
616+
build(builder, result, loopTypes, target, mixedTileSizes,
617+
mixedTileInterchange, applyCleanup, useForall);
618+
}
619+
620+
void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
621+
TypeRange loopTypes, Value target,
622+
ArrayRef<OpFoldResult> mixedTileSizes,
623+
ArrayRef<OpFoldResult> mixedTileInterchange,
624+
bool applyCleanup, bool useForall) {
625+
SmallVector<int64_t> staticTileSizes;
626+
SmallVector<Value> dynamicTileSizes;
627+
dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
628+
SmallVector<int64_t> staticTileInterchange;
629+
SmallVector<Value> dynamicTileInterchange;
630+
dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
631+
staticTileInterchange);
632+
// Call the default builder which sets up the proper operands segment sizes
633+
// attributes for multiple variadic operands. In the absence of this,
634+
// horrible bugs ensue.
635+
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
636+
auto staticTileInterchangeAttr =
637+
builder.getDenseI64ArrayAttr(staticTileInterchange);
638+
unsigned numExpectedLoops =
639+
useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
640+
SmallVector<Type> resultTypes;
641+
resultTypes.reserve(numExpectedLoops);
642+
assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
643+
"expected one loop type or as many as loops");
644+
if (loopTypes.size() == 1)
645+
resultTypes.append(numExpectedLoops, loopTypes[0]);
646+
else
647+
llvm::append_range(resultTypes, loopTypes);
648+
build(builder, result, /*transformed=*/target.getType(),
649+
/*loops=*/resultTypes,
650+
/*target=*/target,
651+
/*tile_sizes=*/dynamicTileSizes,
652+
/*tile_interchange=*/dynamicTileInterchange,
653+
/*static_tile_sizes=*/staticTileSizesAttr,
654+
/*static_tile_interchange=*/staticTileInterchangeAttr,
655+
/*apply_cleanup=*/applyCleanup,
656+
/*use_forall=*/useForall);
657+
}
658+
579659
/// Apply a tiling transformation to all payload ops and store both the
580660
/// tiled operation as well as the created tile loops.
581661
template <typename Range>
@@ -630,10 +710,18 @@ DiagnosedSilenceableFailure
630710
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
631711
mlir::transform::TransformResults &transformResults,
632712
mlir::transform::TransformState &state) {
633-
SmallVector<int64_t> tileSizes =
634-
extractFromIntegerArrayAttr<int64_t>(getTileSizes());
635-
SmallVector<int64_t> tileInterchange =
636-
extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
713+
auto transformOp = cast<TransformOpInterface>(getOperation());
714+
715+
SmallVector<int64_t> tileSizes;
716+
DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
717+
state, transformOp, getMixedTileSizes(), tileSizes);
718+
if (!status.succeeded())
719+
return status;
720+
SmallVector<int64_t> tileInterchange;
721+
status = reifyMixedParamAndHandleResults(
722+
state, transformOp, getMixedTileInterchange(), tileInterchange);
723+
if (!status.succeeded())
724+
return status;
637725

638726
scf::SCFTilingOptions tilingOptions;
639727
tilingOptions.interchangeVector = tileInterchange;
@@ -671,17 +759,18 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
671759
}
672760

673761
LogicalResult transform::FuseOp::verify() {
674-
SmallVector<int64_t> permutation =
675-
extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
676-
auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
677-
if (!std::is_permutation(sequence.begin(), sequence.end(),
678-
permutation.begin(), permutation.end())) {
679-
return emitOpError() << "expects interchange to be a permutation, found "
680-
<< getTileInterchange();
762+
ArrayRef<int64_t> permutation = getStaticTileInterchange();
763+
if (!llvm::any_of(permutation,
764+
[](int64_t v) { return ShapedType::isDynamic(v); })) {
765+
auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
766+
if (!std::is_permutation(sequence.begin(), sequence.end(),
767+
permutation.begin(), permutation.end())) {
768+
return emitOpError() << "expects interchange to be a permutation, found "
769+
<< getTileInterchange();
770+
}
681771
}
682772

683-
SmallVector<int64_t> sizes =
684-
extractFromIntegerArrayAttr<int64_t>(getTileSizes());
773+
ArrayRef<int64_t> sizes = getStaticTileSizes();
685774
size_t numExpectedLoops =
686775
getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
687776
if (numExpectedLoops != getNumResults() - 1)
@@ -690,6 +779,49 @@ LogicalResult transform::FuseOp::verify() {
690779
return success();
691780
}
692781

782+
SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
783+
ValueRange dynamicValues = getTileSizes();
784+
ArrayRef<int64_t> staticValues = getStaticTileSizes();
785+
SmallVector<OpFoldResult> results;
786+
results.reserve(staticValues.size());
787+
unsigned dynamicPos = 0;
788+
Builder builder(getContext());
789+
for (int64_t size : staticValues) {
790+
if (size == ShapedType::kDynamic) {
791+
results.push_back(dynamicValues[dynamicPos++]);
792+
} else {
793+
results.push_back(builder.getIndexAttr(size));
794+
}
795+
}
796+
return results;
797+
}
798+
799+
SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
800+
ValueRange dynamicValues = getTileInterchange();
801+
ArrayRef<int64_t> staticValues = getStaticTileInterchange();
802+
SmallVector<OpFoldResult> results;
803+
results.reserve(staticValues.size());
804+
unsigned dynamicPos = 0;
805+
Builder builder(getContext());
806+
for (int64_t size : staticValues) {
807+
if (size == ShapedType::kDynamic) {
808+
results.push_back(dynamicValues[dynamicPos++]);
809+
} else {
810+
results.push_back(builder.getIndexAttr(size));
811+
}
812+
}
813+
return results;
814+
}
815+
816+
void transform::FuseOp::getEffects(
817+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
818+
consumesHandle(getTargetMutable(), effects);
819+
onlyReadsHandle(getTileSizesMutable(), effects);
820+
onlyReadsHandle(getTileInterchangeMutable(), effects);
821+
producesHandle(getOperation()->getOpResults(), effects);
822+
modifiesPayload(effects);
823+
}
824+
693825
//===----------------------------------------------------------------------===//
694826
// FuseIntoContainingOp
695827
//===----------------------------------------------------------------------===//

mlir/python/mlir/dialects/transform/structured.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ def __init__(
144144
loop_types: Union[Type, Sequence[Type]],
145145
target: Union[Operation, Value, OpView],
146146
*,
147-
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
148-
tile_interchange: OptionalIntList = None,
147+
tile_sizes: Optional[MixedValues] = None,
148+
tile_interchange: Optional[MixedValues] = None,
149149
apply_cleanup: Optional[bool] = False,
150150
use_forall: Optional[bool] = False,
151151
loc=None,
@@ -158,8 +158,8 @@ def __init__(
158158
self,
159159
target: Union[Operation, Value, OpView],
160160
*,
161-
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
162-
tile_interchange: OptionalIntList = None,
161+
tile_sizes: Optional[MixedValues] = None,
162+
tile_interchange: Optional[MixedValues] = None,
163163
apply_cleanup: Optional[bool] = False,
164164
use_forall: Optional[bool] = False,
165165
loc=None,
@@ -172,18 +172,26 @@ def __init__(
172172
loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
173173
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
174174
*,
175-
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
176-
tile_interchange: OptionalIntList = None,
175+
tile_sizes: Optional[MixedValues] = None,
176+
tile_interchange: Optional[MixedValues] = None,
177177
apply_cleanup: Optional[bool] = False,
178178
use_forall: Optional[bool] = False,
179179
loc=None,
180180
ip=None,
181181
):
182182
tile_sizes = tile_sizes if tile_sizes else []
183183
tile_interchange = tile_interchange if tile_interchange else []
184-
_, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
185-
_, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
186-
num_loops = 1 if use_forall else sum(0 if v == 0 else 1 for v in tile_sizes)
184+
(
185+
dynamic_tile_sizes,
186+
static_tile_sizes,
187+
_,
188+
) = _dispatch_dynamic_index_list(tile_sizes)
189+
(
190+
dynamic_tile_interchange,
191+
static_tile_interchange,
192+
_,
193+
) = _dispatch_dynamic_index_list(tile_interchange)
194+
num_loops = 1 if use_forall else sum(0 if v == 0 else 1 for v in static_tile_sizes)
187195

188196
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
189197
loop_types = [transform.AnyOpType.get()] * num_loops
@@ -200,8 +208,10 @@ def __init__(
200208
target.type,
201209
loop_types,
202210
target,
203-
tile_sizes=tile_sizes,
204-
tile_interchange=tile_interchange,
211+
tile_sizes=dynamic_tile_sizes,
212+
tile_interchange=dynamic_tile_interchange,
213+
static_tile_sizes=static_tile_sizes,
214+
static_tile_interchange=static_tile_interchange,
205215
apply_cleanup=apply_cleanup,
206216
use_forall=use_forall,
207217
loc=loc,

0 commit comments

Comments
 (0)