1818#include " mlir/Dialect/Transform/IR/TransformDialect.h"
1919#include " mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
2020#include " mlir/Dialect/Utils/StaticValueUtils.h"
21+ #include " mlir/Dialect/Utils/StructuredOpsUtils.h"
2122#include " mlir/IR/Dominance.h"
2223#include " mlir/IR/OpImplementation.h"
2324#include " mlir/Interfaces/TilingInterface.h"
@@ -60,8 +61,7 @@ template <typename Range>
6061static LogicalResult
6162applyTileAndFuseToAll (RewriterBase &rewriter, Operation *transformOp,
6263 Range &&payloadOps, unsigned numLoops,
63- ArrayRef<OpFoldResult> tileSizes,
64- ArrayRef<int64_t > interchange, bool useForall,
64+ scf::SCFTilingOptions tilingOptions,
6565 TransformResults &transformResults) {
6666 SmallVector<Operation *> tiledOps;
6767 SmallVector<SmallVector<Operation *>> loopOps (numLoops);
@@ -83,12 +83,6 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
8383 }
8484 }
8585
86- scf::SCFTilingOptions tilingOptions;
87- tilingOptions.setTileSizes (tileSizes).setInterchange (interchange);
88- if (useForall) {
89- tilingOptions.setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
90- }
91-
9286 scf::SCFTileAndFuseOptions tileAndFuseOptions;
9387 tileAndFuseOptions.setTilingOptions (tilingOptions);
9488
@@ -157,10 +151,16 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
157151 SmallVector<OpFoldResult> tileSizesOfr =
158152 getAsIndexOpFoldResult (rewriter.getContext (), tileSizes);
159153
154+ scf::SCFTilingOptions tilingOptions;
155+ tilingOptions.setTileSizes (tileSizesOfr).setInterchange (tileInterchange);
156+ if (getUseForall ()) {
157+ tilingOptions.setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
158+ }
159+
160160 LogicalResult result = applyTileAndFuseToAll (
161161 rewriter, getOperation (), state.getPayloadOps (getTarget ()),
162- tileSizes.size () - llvm::count (tileSizes, 0 ), tileSizesOfr ,
163- tileInterchange, getUseForall (), transformResults);
162+ tileSizes.size () - llvm::count (tileSizes, 0 ), tilingOptions ,
163+ transformResults);
164164 return failed (result) ? DiagnosedSilenceableFailure::definiteFailure ()
165165 : DiagnosedSilenceableFailure::success ();
166166}
@@ -399,6 +399,75 @@ void transform::TestFuseUsingForallOp::getEffects(
399399 modifiesPayload (effects);
400400}
401401
402+ // ===----------------------------------------------------------------------===//
403+ // TestTileAndFuseOuterParallelPartialReduction
404+ // ===----------------------------------------------------------------------===//
405+
406+ DiagnosedSilenceableFailure
407+ transform::TestTileAndFuseOuterParallelPartialReductionOp::apply (
408+ TransformRewriter &rewriter, TransformResults &transformResults,
409+ TransformState &state) {
410+ auto target =
411+ dyn_cast<TilingInterface>(*state.getPayloadOps (getRootOp ()).begin ());
412+ if (!target) {
413+ emitOpError (" expected root operation to implement `TilingInterface`" );
414+ return DiagnosedSilenceableFailure::definiteFailure ();
415+ }
416+
417+ SmallVector<unsigned > reductionDims =
418+ extractFromIntegerArrayAttr<unsigned >(getReductionDims ());
419+ if (reductionDims.empty ()) {
420+ for (auto [index, iterator] :
421+ llvm::enumerate (target.getLoopIteratorTypes ()))
422+ if (iterator == utils::IteratorType::reduction)
423+ reductionDims.push_back (index);
424+ }
425+
426+ if (reductionDims.empty ()) {
427+ emitOpError (
428+ " no reduction dimension specified or found in the target operation" );
429+ return DiagnosedSilenceableFailure::definiteFailure ();
430+ }
431+
432+ SmallVector<int64_t > reductionTileSizes =
433+ extractFromIntegerArrayAttr<int64_t >(getTileSizes ());
434+ if (reductionTileSizes.size () != reductionDims.size ()) {
435+ emitOpError (
436+ " missing tile sizes for reduction dimensions that are to be tiled" );
437+ return DiagnosedSilenceableFailure::definiteFailure ();
438+ }
439+
440+ // Adjust tile sizes so that it corresponds to the reduction iterator types.
441+ SmallVector<OpFoldResult> tileSizes;
442+ int reductionTileSizeNum = 0 ;
443+ OpFoldResult zero = rewriter.getIndexAttr (0 );
444+ for (auto iterator : target.getLoopIteratorTypes ()) {
445+ if (iterator == utils::IteratorType::parallel) {
446+ tileSizes.push_back (zero);
447+ continue ;
448+ }
449+ tileSizes.push_back (
450+ rewriter.getIndexAttr (reductionTileSizes[reductionTileSizeNum++]));
451+ }
452+
453+ scf::SCFTilingOptions tilingOptions;
454+ tilingOptions.setTileSizes (tileSizes)
455+ .setLoopType (scf::SCFTilingOptions::LoopType::ForallOp)
456+ .setReductionTilingStrategy (
457+ ReductionTilingStrategy::PartialReductionOuterParallel)
458+ .setReductionDims (reductionDims);
459+ if (auto mapping = getMapping ()) {
460+ tilingOptions.setMapping (getMapping ().value ());
461+ }
462+
463+ LogicalResult result = applyTileAndFuseToAll (
464+ rewriter, getOperation (), state.getPayloadOps (getRootOp ()),
465+ /* numLoops =*/ 1 , tilingOptions, transformResults);
466+
467+ return failed (result) ? DiagnosedSilenceableFailure::definiteFailure ()
468+ : DiagnosedSilenceableFailure::success ();
469+ }
470+
402471#define GET_OP_CLASSES
403472#include " TestTilingInterfaceTransformOps.cpp.inc"
404473
0 commit comments