@@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
576576// FuseOp
577577// ===----------------------------------------------------------------------===//
578578
579+ void transform::FuseOp::build (OpBuilder &builder, OperationState &result,
580+ TypeRange loopTypes, Value target,
581+ ArrayRef<int64_t > staticTileSizes,
582+ ArrayRef<int64_t > staticTileInterchange,
583+ bool applyCleanup, bool useForall) {
584+ return build (
585+ builder, result, loopTypes,
586+ /* target=*/ target,
587+ /* mixedTileSizes=*/
588+ getAsOpFoldResult (builder.getI64ArrayAttr (staticTileSizes)),
589+ /* mixedTileInterchange=*/
590+ getAsOpFoldResult (builder.getI64ArrayAttr (staticTileInterchange)),
591+ applyCleanup, useForall);
592+ }
593+
594+ void transform::FuseOp::build (OpBuilder &builder, OperationState &result,
595+ Value target, ArrayRef<int64_t > staticTileSizes,
596+ ArrayRef<int64_t > staticTileInterchange,
597+ bool applyCleanup, bool useForall) {
598+ return build (
599+ builder, result,
600+ /* target=*/ target,
601+ /* mixedTileSizes=*/
602+ getAsOpFoldResult (builder.getI64ArrayAttr (staticTileSizes)),
603+ /* mixedTileInterchange=*/
604+ getAsOpFoldResult (builder.getI64ArrayAttr (staticTileInterchange)),
605+ applyCleanup, useForall);
606+ }
607+
608+ void transform::FuseOp::build (OpBuilder &builder, OperationState &result,
609+ Value target,
610+ ArrayRef<OpFoldResult> mixedTileSizes,
611+ ArrayRef<OpFoldResult> mixedTileInterchange,
612+ bool applyCleanup, bool useForall) {
613+ // Loop types are automaticaly splat by the callee, setting up one is
614+ // enough.
615+ SmallVector<Type> loopTypes (1 , builder.getType <transform::AnyOpType>());
616+ build (builder, result, loopTypes, target, mixedTileSizes,
617+ mixedTileInterchange, applyCleanup, useForall);
618+ }
619+
620+ void transform::FuseOp::build (OpBuilder &builder, OperationState &result,
621+ TypeRange loopTypes, Value target,
622+ ArrayRef<OpFoldResult> mixedTileSizes,
623+ ArrayRef<OpFoldResult> mixedTileInterchange,
624+ bool applyCleanup, bool useForall) {
625+ SmallVector<int64_t > staticTileSizes;
626+ SmallVector<Value> dynamicTileSizes;
627+ dispatchIndexOpFoldResults (mixedTileSizes, dynamicTileSizes, staticTileSizes);
628+ SmallVector<int64_t > staticTileInterchange;
629+ SmallVector<Value> dynamicTileInterchange;
630+ dispatchIndexOpFoldResults (mixedTileInterchange, dynamicTileInterchange,
631+ staticTileInterchange);
632+ // Call the default builder which sets up the proper operands segment sizes
633+ // attributes for multiple variadic operands. In the absence of this,
634+ // horrible bugs ensue.
635+ auto staticTileSizesAttr = builder.getDenseI64ArrayAttr (staticTileSizes);
636+ auto staticTileInterchangeAttr =
637+ builder.getDenseI64ArrayAttr (staticTileInterchange);
638+ unsigned numExpectedLoops =
639+ useForall ? 1 : staticTileSizes.size () - llvm::count (staticTileSizes, 0 );
640+ SmallVector<Type> resultTypes;
641+ resultTypes.reserve (numExpectedLoops);
642+ assert ((loopTypes.size () == 1 || loopTypes.size () == numExpectedLoops) &&
643+ " expected one loop type or as many as loops" );
644+ if (loopTypes.size () == 1 )
645+ resultTypes.append (numExpectedLoops, loopTypes[0 ]);
646+ else
647+ llvm::append_range (resultTypes, loopTypes);
648+ build (builder, result, /* transformed=*/ target.getType (),
649+ /* loops=*/ resultTypes,
650+ /* target=*/ target,
651+ /* tile_sizes=*/ dynamicTileSizes,
652+ /* tile_interchange=*/ dynamicTileInterchange,
653+ /* static_tile_sizes=*/ staticTileSizesAttr,
654+ /* static_tile_interchange=*/ staticTileInterchangeAttr,
655+ /* apply_cleanup=*/ applyCleanup,
656+ /* use_forall=*/ useForall);
657+ }
658+
579659// / Apply a tiling transformation to all payload ops and store both the
580660// / tiled operation as well as the created tile loops.
581661template <typename Range>
@@ -630,10 +710,18 @@ DiagnosedSilenceableFailure
630710transform::FuseOp::apply (transform::TransformRewriter &rewriter,
631711 mlir::transform::TransformResults &transformResults,
632712 mlir::transform::TransformState &state) {
633- SmallVector<int64_t > tileSizes =
634- extractFromIntegerArrayAttr<int64_t >(getTileSizes ());
635- SmallVector<int64_t > tileInterchange =
636- extractFromIntegerArrayAttr<int64_t >(getTileInterchange ());
713+ auto transformOp = cast<TransformOpInterface>(getOperation ());
714+
715+ SmallVector<int64_t > tileSizes;
716+ DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults (
717+ state, transformOp, getMixedTileSizes (), tileSizes);
718+ if (!status.succeeded ())
719+ return status;
720+ SmallVector<int64_t > tileInterchange;
721+ status = reifyMixedParamAndHandleResults (
722+ state, transformOp, getMixedTileInterchange (), tileInterchange);
723+ if (!status.succeeded ())
724+ return status;
637725
638726 scf::SCFTilingOptions tilingOptions;
639727 tilingOptions.interchangeVector = tileInterchange;
@@ -671,17 +759,18 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
671759}
672760
673761LogicalResult transform::FuseOp::verify () {
674- SmallVector<int64_t > permutation =
675- extractFromIntegerArrayAttr<int64_t >(getTileInterchange ());
676- auto sequence = llvm::to_vector (llvm::seq<int64_t >(0 , permutation.size ()));
677- if (!std::is_permutation (sequence.begin (), sequence.end (),
678- permutation.begin (), permutation.end ())) {
679- return emitOpError () << " expects interchange to be a permutation, found "
680- << getTileInterchange ();
762+ ArrayRef<int64_t > permutation = getStaticTileInterchange ();
763+ if (!llvm::any_of (permutation,
764+ [](int64_t v) { return ShapedType::isDynamic (v); })) {
765+ auto sequence = llvm::to_vector (llvm::seq<int64_t >(0 , permutation.size ()));
766+ if (!std::is_permutation (sequence.begin (), sequence.end (),
767+ permutation.begin (), permutation.end ())) {
768+ return emitOpError () << " expects interchange to be a permutation, found "
769+ << getTileInterchange ();
770+ }
681771 }
682772
683- SmallVector<int64_t > sizes =
684- extractFromIntegerArrayAttr<int64_t >(getTileSizes ());
773+ ArrayRef<int64_t > sizes = getStaticTileSizes ();
685774 size_t numExpectedLoops =
686775 getUseForall () ? 1 : sizes.size () - llvm::count (sizes, 0 );
687776 if (numExpectedLoops != getNumResults () - 1 )
@@ -690,6 +779,49 @@ LogicalResult transform::FuseOp::verify() {
690779 return success ();
691780}
692781
782+ SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes () {
783+ ValueRange dynamicValues = getTileSizes ();
784+ ArrayRef<int64_t > staticValues = getStaticTileSizes ();
785+ SmallVector<OpFoldResult> results;
786+ results.reserve (staticValues.size ());
787+ unsigned dynamicPos = 0 ;
788+ Builder builder (getContext ());
789+ for (int64_t size : staticValues) {
790+ if (size == ShapedType::kDynamic ) {
791+ results.push_back (dynamicValues[dynamicPos++]);
792+ } else {
793+ results.push_back (builder.getIndexAttr (size));
794+ }
795+ }
796+ return results;
797+ }
798+
799+ SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange () {
800+ ValueRange dynamicValues = getTileInterchange ();
801+ ArrayRef<int64_t > staticValues = getStaticTileInterchange ();
802+ SmallVector<OpFoldResult> results;
803+ results.reserve (staticValues.size ());
804+ unsigned dynamicPos = 0 ;
805+ Builder builder (getContext ());
806+ for (int64_t size : staticValues) {
807+ if (size == ShapedType::kDynamic ) {
808+ results.push_back (dynamicValues[dynamicPos++]);
809+ } else {
810+ results.push_back (builder.getIndexAttr (size));
811+ }
812+ }
813+ return results;
814+ }
815+
816+ void transform::FuseOp::getEffects (
817+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
818+ consumesHandle (getTargetMutable (), effects);
819+ onlyReadsHandle (getTileSizesMutable (), effects);
820+ onlyReadsHandle (getTileInterchangeMutable (), effects);
821+ producesHandle (getOperation ()->getOpResults (), effects);
822+ modifiesPayload (effects);
823+ }
824+
693825// ===----------------------------------------------------------------------===//
694826// FuseIntoContainingOp
695827// ===----------------------------------------------------------------------===//
0 commit comments