diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h index 2250db823b551..3652435e4c59e 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h @@ -223,6 +223,10 @@ class DeadCodeAnalysis : public DataFlowAnalysis { /// Get the constant values of the operands of the operation. Returns /// std::nullopt if any of the operand lattices are uninitialized. std::optional> getOperandValues(Operation *op); + + /// Get the constant values of the operands of the operation. + /// If the operand lattices are uninitialized, add a null attribute for those. + SmallVector getOperandValuesBestEffort(Operation *op); /// The top-level operation the analysis is running on. This is used to detect /// if a callable is outside the scope of the analysis and thus must be diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td index c07ab9deca48c..efbe15eb00d7a 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td @@ -138,6 +138,16 @@ def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> { return $_op.getOperand($_op.getStoredValOperandIndex()); }] >, + InterfaceMethod< + /*desc=*/"Returns the value to store.", + /*retTy=*/"::mlir::OpOperand&", + /*methodName=*/"getValueToStoreMutable", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return $_op->getOpOperand($_op.getStoredValOperandIndex()); + }] + >, ]; } diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index 6cd3408e2b2e9..e6450defb4376 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -119,7 +119,8 @@ def AffineForOp : Affine_Op<"for", 
ImplicitAffineTerminator, ConditionallySpeculatable, RecursiveMemoryEffects, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h index 7fe1f6d48ceeb..1822d535dfe25 100644 --- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h +++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h @@ -16,6 +16,7 @@ #define MLIR_DIALECT_AFFINE_LOOPUTILS_H #include "mlir/IR/Block.h" +#include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" #include @@ -101,7 +102,7 @@ LogicalResult affineForOpBodySkew(AffineForOp forOp, ArrayRef shifts, /// Identify valid and profitable bands of loops to tile. This is currently just /// a temporary placeholder to test the mechanics of tiled code generation. /// Returns all maximal outermost perfect loop nests to tile. -void getTileableBands(func::FuncOp f, +void getTileableBands(Operation *f, std::vector> *bands); /// Tiles the specified band of perfectly nested loops creating tile-space loops @@ -272,7 +273,7 @@ void mapLoopToProcessorIds(scf::ForOp forOp, ArrayRef processorId, ArrayRef numProcessors); /// Gathers all AffineForOps in 'func.func' grouped by loop depth. 
-void gatherLoops(func::FuncOp func, +void gatherLoops(Operation* func, std::vector> &depthToLoops); /// Creates an AffineForOp while ensuring that the lower and upper bounds are diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h index 96bd3c6a9a7bc..dbca6821c82ef 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.h +++ b/mlir/include/mlir/Dialect/Affine/Passes.h @@ -14,19 +14,38 @@ #ifndef MLIR_DIALECT_AFFINE_PASSES_H #define MLIR_DIALECT_AFFINE_PASSES_H +#include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include +#include namespace mlir { namespace func { class FuncOp; } // namespace func +namespace memref { +class MemRefDialect; +} // namespace memref namespace affine { class AffineForOp; + +class AffineScopePassBase : public OperationPass<> { + using OperationPass<>::OperationPass; + + bool canScheduleOn(RegisteredOperationName opInfo) const final { + return opInfo.hasTrait() && + opInfo.getStringRef() != ModuleOp::getOperationName(); + } + + bool shouldImplicitlyNestOn(llvm::StringRef anchorName) const final { + return anchorName == ModuleOp::getOperationName(); + } +}; + /// Fusion mode to attempt. The default mode `Greedy` does both /// producer-consumer and sibling fusion. enum FusionMode { Greedy, ProducerConsumer, Sibling }; @@ -37,40 +56,46 @@ enum FusionMode { Greedy, ProducerConsumer, Sibling }; /// Creates a simplification pass for affine structures (maps and sets). In /// addition, this pass also normalizes memrefs to have the trivial (identity) /// layout map. -std::unique_ptr> +std::unique_ptr createSimplifyAffineStructuresPass(); /// Creates a loop invariant code motion pass that hoists loop invariant /// operations out of affine loops. -std::unique_ptr> +std::unique_ptr createAffineLoopInvariantCodeMotionPass(); +/// Creates a pass to convert all parallel affine.for's into 1-d affine.parallel +/// ops. 
+std::unique_ptr createAffineParallelizePass(); + +/// Creates a pass that converts some memref operators to affine operators. +std::unique_ptr createRaiseMemrefToAffine(); + /// Apply normalization transformations to affine loop-like ops. If /// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the /// loop is replaced by its loop body). -std::unique_ptr> +std::unique_ptr createAffineLoopNormalizePass(bool promoteSingleIter = false); /// Performs packing (or explicit copying) of accessed memref regions into /// buffers in the specified faster memory space through either pointwise copies /// or DMA operations. -std::unique_ptr> createAffineDataCopyGenerationPass( +std::unique_ptr createAffineDataCopyGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace = 0, int minDmaTransferSize = 1024, uint64_t fastMemCapacityBytes = std::numeric_limits::max()); /// Overload relying on pass options for initialization. -std::unique_ptr> +std::unique_ptr createAffineDataCopyGenerationPass(); /// Creates a pass to replace affine memref accesses by scalars using store to /// load forwarding and redundant load elimination; consequently also eliminate /// dead allocs. -std::unique_ptr> -createAffineScalarReplacementPass(); +std::unique_ptr createAffineScalarReplacementPass(); /// Creates a pass that transforms perfectly nested loops with independent /// bounds into a single loop. -std::unique_ptr> createLoopCoalescingPass(); +std::unique_ptr createLoopCoalescingPass(); /// Creates a loop fusion pass which fuses affine loop nests at the top-level of /// the operation the pass is created on according to the type of fusion @@ -83,10 +108,10 @@ createLoopFusionPass(unsigned fastMemorySpace = 0, enum FusionMode fusionMode = FusionMode::Greedy); /// Creates a pass to perform tiling on loop nests. -std::unique_ptr> +std::unique_ptr createLoopTilingPass(uint64_t cacheSizeBytes); /// Overload relying on pass options for initialization. 
-std::unique_ptr> createLoopTilingPass(); +std::unique_ptr createLoopTilingPass(); /// Creates a loop unrolling pass with the provided parameters. /// 'getUnrollFactor' is a function callback for clients to supply a function @@ -94,7 +119,7 @@ std::unique_ptr> createLoopTilingPass(); /// factors supplied through other means. If -1 is passed as the unrollFactor /// and no callback is provided, anything passed from the command-line (if at /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor). -std::unique_ptr> createLoopUnrollPass( +std::unique_ptr createLoopUnrollPass( int unrollFactor = -1, bool unrollUpToFactor = false, bool unrollFull = false, const std::function &getUnrollFactor = nullptr); @@ -102,12 +127,12 @@ std::unique_ptr> createLoopUnrollPass( /// Creates a loop unroll jam pass to unroll jam by the specified factor. A /// factor of -1 lets the pass use the default factor or the one on the command /// line if provided. -std::unique_ptr> +std::unique_ptr createLoopUnrollAndJamPass(int unrollJamFactor = -1); /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -std::unique_ptr> createPipelineDataTransferPass(); +std::unique_ptr createPipelineDataTransferPass(); /// Creates a pass to expand affine index operations into more fundamental /// operations (not necessarily restricted to Affine dialect). 
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td index 0b8d5b7d94861..f54c8efc43a70 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -15,7 +15,10 @@ include "mlir/Pass/PassBase.td" -def AffineDataCopyGeneration : Pass<"affine-data-copy-generate", "func::FuncOp"> { +class AffineScopePass + : PassBase; + +def AffineDataCopyGeneration : AffineScopePass<"affine-data-copy-generate"> { let summary = "Generate explicit copying for affine memory operations"; let constructor = "mlir::affine::createAffineDataCopyGenerationPass()"; let dependentDialects = ["memref::MemRefDialect"]; @@ -43,7 +46,7 @@ def AffineDataCopyGeneration : Pass<"affine-data-copy-generate", "func::FuncOp"> ]; } -def AffineLoopFusion : Pass<"affine-loop-fusion"> { +def AffineLoopFusion : AffineScopePass<"affine-loop-fusion"> { let summary = "Fuse affine loop nests"; let description = [{ This pass performs fusion of loop nests using a slicing-based approach. 
The @@ -178,12 +181,12 @@ def AffineLoopFusion : Pass<"affine-loop-fusion"> { } def AffineLoopInvariantCodeMotion - : Pass<"affine-loop-invariant-code-motion", "func::FuncOp"> { + : AffineScopePass<"affine-loop-invariant-code-motion"> { let summary = "Hoist loop invariant instructions outside of affine loops"; let constructor = "mlir::affine::createAffineLoopInvariantCodeMotionPass()"; } -def AffineLoopTiling : Pass<"affine-loop-tile", "func::FuncOp"> { +def AffineLoopTiling : AffineScopePass<"affine-loop-tile"> { let summary = "Tile affine loop nests"; let constructor = "mlir::affine::createLoopTilingPass()"; let options = [ @@ -199,7 +202,7 @@ def AffineLoopTiling : Pass<"affine-loop-tile", "func::FuncOp"> { ]; } -def AffineLoopUnroll : InterfacePass<"affine-loop-unroll", "FunctionOpInterface"> { +def AffineLoopUnroll : AffineScopePass<"affine-loop-unroll"> { let summary = "Unroll affine loops"; let constructor = "mlir::affine::createLoopUnrollPass()"; let options = [ @@ -219,7 +222,7 @@ def AffineLoopUnroll : InterfacePass<"affine-loop-unroll", "FunctionOpInterface" ]; } -def AffineLoopUnrollAndJam : InterfacePass<"affine-loop-unroll-jam", "FunctionOpInterface"> { +def AffineLoopUnrollAndJam : AffineScopePass<"affine-loop-unroll-jam"> { let summary = "Unroll and jam affine loops"; let constructor = "mlir::affine::createLoopUnrollAndJamPass()"; let options = [ @@ -230,7 +233,7 @@ def AffineLoopUnrollAndJam : InterfacePass<"affine-loop-unroll-jam", "FunctionOp } def AffinePipelineDataTransfer - : Pass<"affine-pipeline-data-transfer", "func::FuncOp"> { + : AffineScopePass<"affine-pipeline-data-transfer"> { let summary = "Pipeline non-blocking data transfers between explicitly " "managed levels of the memory hierarchy"; let description = [{ @@ -298,7 +301,7 @@ def AffinePipelineDataTransfer let constructor = "mlir::affine::createPipelineDataTransferPass()"; } -def AffineScalarReplacement : Pass<"affine-scalrep", "func::FuncOp"> { +def AffineScalarReplacement : 
AffineScopePass<"affine-scalrep"> { let summary = "Replace affine memref accesses by scalars by forwarding stores " "to loads and eliminating redundant loads"; let description = [{ @@ -344,7 +347,7 @@ def AffineScalarReplacement : Pass<"affine-scalrep", "func::FuncOp"> { let constructor = "mlir::affine::createAffineScalarReplacementPass()"; } -def AffineVectorize : Pass<"affine-super-vectorize", "func::FuncOp"> { +def AffineVectorize : AffineScopePass<"affine-super-vectorize"> { let summary = "Vectorize to a target independent n-D vector abstraction"; let dependentDialects = ["vector::VectorDialect"]; let options = [ @@ -368,7 +371,7 @@ def AffineVectorize : Pass<"affine-super-vectorize", "func::FuncOp"> { ]; } -def AffineParallelize : Pass<"affine-parallelize", "func::FuncOp"> { +def AffineParallelize : AffineScopePass<"affine-parallelize"> { let summary = "Convert affine.for ops into 1-D affine.parallel"; let options = [ Option<"maxNested", "max-nested", "unsigned", /*default=*/"-1u", @@ -380,7 +383,7 @@ def AffineParallelize : Pass<"affine-parallelize", "func::FuncOp"> { ]; } -def AffineLoopNormalize : Pass<"affine-loop-normalize", "func::FuncOp"> { +def AffineLoopNormalize : AffineScopePass<"affine-loop-normalize"> { let summary = "Apply normalization transformations to affine loop-like ops"; let constructor = "mlir::affine::createAffineLoopNormalizePass()"; let options = [ @@ -389,14 +392,26 @@ def AffineLoopNormalize : Pass<"affine-loop-normalize", "func::FuncOp"> { ]; } -def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> { +def LoopCoalescing : AffineScopePass<"affine-loop-coalescing"> { let summary = "Coalesce nested loops with independent bounds into a single " "loop"; let constructor = "mlir::affine::createLoopCoalescingPass()"; let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"]; } -def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> { +def RaiseMemrefDialect : 
AffineScopePass<"affine-raise-from-memref"> { + let summary = "Turn some memref operators to affine operators where supported"; + let description = [{ + Raise memref.load and memref.store to affine.store and affine.load, inferring + the affine map of those operators if needed. This allows passes like --affine-scalrep + to optimize those loads and stores (forwarding them or eliminating them). + They can be turned back to memref dialect ops with --lower-affine. + }]; + let constructor = "mlir::affine::createRaiseMemrefToAffine()"; + let dependentDialects = ["affine::AffineDialect"]; +} + +def SimplifyAffineStructures : AffineScopePass<"affine-simplify-structures"> { let summary = "Simplify affine expressions in maps/sets and normalize " "memrefs"; let constructor = "mlir::affine::createSimplifyAffineStructuresPass()"; diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h index ff1900bc8f2eb..93b7af7d24f85 100644 --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -17,6 +17,8 @@ #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/LogicalResult.h" #include namespace mlir { @@ -105,7 +107,7 @@ struct VectorizationStrategy { /// Replace affine store and load accesses by scalars by forwarding stores to /// loads and eliminate invariant affine loads; consequently, eliminate dead /// allocs. -void affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo, +void affineScalarReplace(Operation *parentOp, DominanceInfo &domInfo, PostDominanceInfo &postDomInfo, AliasAnalysis &analysis); @@ -338,6 +340,20 @@ OpFoldResult linearizeIndex(OpBuilder &builder, Location loc, ArrayRef multiIndex, ArrayRef basis); +/// Given a set of indices into a memref which may be computed using +/// arith ops, try to compute each value to an affine expr. 
This is +/// only possible if the indices are an expression of valid dims and +/// args. If this succeeds, the affine map is populated, along with +/// the map arguments (concrete bindings for dims and symbols). +LogicalResult +convertValuesToAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, + AffineMap &map, + llvm::SmallVectorImpl &mapArgs); +LogicalResult +convertValuesToAffineMapAndArgs(MLIRContext *ctx, + ArrayRef indices, AffineMap &map, + llvm::SmallVectorImpl &mapArgs); + /// Ensure that all operations that could be executed after `start` /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path /// between the operations) do not have the potential memory effect diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index 493180cd54e5b..429a76c45d092 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -4,6 +4,8 @@ #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" +#include +#include namespace mlir { class FunctionOpInterface; @@ -23,6 +25,20 @@ struct OneShotBufferizationOptions; /// Maps from symbol table to its corresponding dealloc helper function. 
using DeallocHelperMap = llvm::DenseMap; + +class BufferScopePassBase : public OperationPass<> { + using OperationPass<>::OperationPass; + + bool canScheduleOn(RegisteredOperationName opInfo) const final { + return opInfo.hasTrait() && + opInfo.getStringRef() != ModuleOp::getOperationName(); + } + + bool shouldImplicitlyNestOn(llvm::StringRef anchorName) const final { + return anchorName == ModuleOp::getOperationName(); + } +}; + //===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index 3bbb8b02c644e..972dd2236d672 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -11,8 +11,12 @@ include "mlir/Pass/PassBase.td" -def OwnershipBasedBufferDeallocationPass - : Pass<"ownership-based-buffer-deallocation"> { +class BufferScopePass + : PassBase; + + +def OwnershipBasedBufferDeallocationPass : BufferScopePass< + "ownership-based-buffer-deallocation"> { let summary = "Adds all required dealloc operations for all allocations in " "the input program"; let description = [{ @@ -152,8 +156,8 @@ def OwnershipBasedBufferDeallocationPass ]; } -def BufferDeallocationSimplificationPass - : Pass<"buffer-deallocation-simplification"> { +def BufferDeallocationSimplificationPass : + BufferScopePass<"buffer-deallocation-simplification"> { let summary = "Optimizes `bufferization.dealloc` operation for more " "efficient codegen"; let description = [{ @@ -170,7 +174,7 @@ def BufferDeallocationSimplificationPass } def OptimizeAllocationLivenessPass - : Pass<"optimize-allocation-liveness", "func::FuncOp"> { + : BufferScopePass<"optimize-allocation-liveness"> { let summary = "This pass optimizes the liveness of temp allocations in the " "input 
function"; let description = [{ @@ -184,7 +188,7 @@ def OptimizeAllocationLivenessPass let dependentDialects = ["mlir::memref::MemRefDialect"]; } -def LowerDeallocationsPass : Pass<"bufferization-lower-deallocations"> { +def LowerDeallocationsPass : BufferScopePass<"bufferization-lower-deallocations"> { let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc`" "operations"; let description = [{ @@ -204,7 +208,7 @@ def LowerDeallocationsPass : Pass<"bufferization-lower-deallocations"> { ]; } -def BufferHoistingPass : Pass<"buffer-hoisting", "func::FuncOp"> { +def BufferHoistingPass : BufferScopePass<"buffer-hoisting"> { let summary = "Optimizes placement of allocation operations by moving them " "into common dominators and out of nested regions"; let description = [{ @@ -213,7 +217,7 @@ def BufferHoistingPass : Pass<"buffer-hoisting", "func::FuncOp"> { }]; } -def BufferLoopHoistingPass : Pass<"buffer-loop-hoisting", "func::FuncOp"> { +def BufferLoopHoistingPass : BufferScopePass<"buffer-loop-hoisting"> { let summary = "Optimizes placement of allocation operations by moving them " "out of loop nests"; let description = [{ @@ -462,8 +466,7 @@ def OneShotBufferizePass : Pass<"one-shot-bufferize", "ModuleOp"> { ]; } -def PromoteBuffersToStackPass - : Pass<"promote-buffers-to-stack", "func::FuncOp"> { +def PromoteBuffersToStackPass : BufferScopePass<"promote-buffers-to-stack"> { let summary = "Promotes heap-based allocations to automatically managed " "stack-based allocations"; let description = [{ @@ -483,7 +486,7 @@ def PromoteBuffersToStackPass ]; } -def EmptyTensorEliminationPass : Pass<"eliminate-empty-tensors"> { +def EmptyTensorEliminationPass : BufferScopePass<"eliminate-empty-tensors"> { let summary = "Try to eliminate all tensor.empty ops."; let description = [{ Try to eliminate "tensor.empty" ops inside `op`. 
This transformation looks diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index e4dd458eaff84..f51cf6b97bb83 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -318,7 +318,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - SameVariadicOperandSize, + AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> { let summary = "Reduce operator"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td index 86a2b3c21faf0..d134d1d8acff0 100644 --- a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td @@ -11,7 +11,7 @@ include "mlir/Pass/PassBase.td" -def CheckUsesPass : Pass<"transform-dialect-check-uses"> { +def CheckUsesPass : Pass<"transform-dialect-check-uses", "mlir::ModuleOp"> { let summary = "warn about potential use-after-free in the transform dialect"; let description = [{ This pass analyzes operations from the transform dialect and its extensions @@ -32,7 +32,7 @@ def CheckUsesPass : Pass<"transform-dialect-check-uses"> { }]; } -def InferEffectsPass : Pass<"transform-infer-effects"> { +def InferEffectsPass : Pass<"transform-infer-effects", "mlir::ModuleOp"> { let summary = "infer transform side effects for symbols"; let description = [{ This pass analyzes the definitions of transform dialect callable symbol @@ -42,7 +42,7 @@ def InferEffectsPass : Pass<"transform-infer-effects"> { }]; } -def PreloadLibraryPass : Pass<"transform-preload-library"> { +def PreloadLibraryPass : Pass<"transform-preload-library", "mlir::ModuleOp"> { let summary = "preload transform dialect library"; let description = [{ This pass preloads a 
transform library and makes it available to subsequent @@ -61,7 +61,7 @@ def PreloadLibraryPass : Pass<"transform-preload-library"> { ]; } -def InterpreterPass : Pass<"transform-interpreter"> { +def InterpreterPass : Pass<"transform-interpreter", "mlir::ModuleOp"> { let summary = "transform dialect interpreter"; let description = [{ This pass runs the transform dialect interpreter and applies the named diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 7725a3a2910bd..6b9eb2bd3d01c 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -14,6 +14,7 @@ #include "mlir/Pass/PassRegistry.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" #include namespace mlir { @@ -193,6 +194,17 @@ class Pass { /// operations they operate on. virtual bool canScheduleOn(RegisteredOperationName opName) const = 0; + /// Indicate whether this pass should implicitly nest itself in the pass manager, + /// when there is a mismatch between the anchor type and this pass' anchor type. + /// By default passes that have a specific anchor name nest themselves, and passes + /// that can handle any anchor don't. + /// + /// This is only ever called if the PassManager uses implicit nesting. Passes are + /// also never implicitly nested on a pass manager with anchor "any". + virtual bool shouldImplicitlyNestOn(StringRef anchorName) const { + return getOpName() && *getOpName() != anchorName; + } + /// Schedule an arbitrary pass pipeline on the provided operation. /// This can be invoke any time in a pass to dynamic schedule more passes. /// The provided operation must be the current one or one nested below. 
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index d9bab431e2e0c..4d688a049c194 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -54,7 +54,7 @@ class OpPassManager { Implicit, /// Explicit nesting behavior. This requires that any passes added to this /// pass manager support its operation type. - Explicit + Explicit, }; /// Construct a new op-agnostic ("any") pass manager with the given operation @@ -165,6 +165,13 @@ class OpPassManager { /// Return the current nesting mode. Nesting getNesting(); + + /// Make the pass pipeline fetch its anchors by doing a recursive walk, + /// instead of being anchored on the root of the IR. + void setRecursiveAnchorFetching(bool enabled = true); + + bool hasRecursiveAnchor() const; + private: /// Initialize all of the passes within this pass manager with the given /// initialization generation. The initialization generation is used to detect @@ -222,7 +229,7 @@ using ReproducerStreamFactory = std::function(std::string &error)>; std::string -makeReproducer(StringRef anchorName, +makeReproducer(StringRef anchorName, bool hasRecursiveAnchor, const llvm::iterator_range &passes, Operation *op, StringRef outputFile, bool disableThreads = false, bool verifyPasses = false); diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h index 09bd86b9581df..24270f8f40c0f 100644 --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -14,12 +14,15 @@ #define MLIR_TOOLS_MLIROPT_MLIROPTMAIN_H #include "mlir/Debug/CLOptionsSetup.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Support/ToolUtilities.h" #include "llvm/ADT/StringRef.h" #include #include #include +#include +#include namespace llvm { class raw_ostream; @@ -141,6 +144,18 @@ class MlirOptMainConfig { } bool shouldListPasses() const { return listPassesFlag; } + MlirOptMainConfig& 
setPassPipelineAnchor(std::string&& name) { + passPipelineAnchorFlag = std::move(name); + return *this; + } + + std::optional getPassPipelineAnchor() const { + if (passPipelineAnchorFlag.empty()) { + return std::nullopt; + } + return passPipelineAnchorFlag; + } + /// Enable running the reproducer information stored in resources (if /// present). MlirOptMainConfig &runReproducer(bool enableReproducer) { @@ -274,6 +289,10 @@ class MlirOptMainConfig { /// Merge output chunks into one file using the given marker. std::string outputSplitMarkerFlag = ""; + /// Specify an operation name as the anchor for the CLI pass pipeline. + /// By default the pipeline is anchored on the root of the IR. + std::string passPipelineAnchorFlag = ""; + /// Use an explicit top-level module op during parsing. bool useExplicitModuleFlag = false; diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index a39ab77fc8fb3..b7cf19bfb790a 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -273,7 +273,7 @@ def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> { let constructor = "mlir::createGenerateRuntimeVerificationPass()"; } -def Inliner : Pass<"inline"> { +def Inliner : Pass<"inline", "mlir::ModuleOp"> { let summary = "Inline function calls"; let constructor = "mlir::createInlinerPass()"; let options = [ diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index e805e21d878bf..ca14d58ab66d7 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -341,15 +341,20 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { /// constant value lattices are uninitialized, return std::nullopt to indicate /// the analysis should bail out. 
static std::optional> getOperandValuesImpl( - Operation *op, + Operation *op, bool failIfAnyNull, function_ref *(Value)> getLattice) { SmallVector operands; operands.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { const Lattice *cv = getLattice(operand); // If any of the operands' values are uninitialized, bail out. - if (cv->getValue().isUninitialized()) - return {}; + if (cv->getValue().isUninitialized()) { + if (failIfAnyNull) + return {}; + operands.emplace_back(); + continue; + } + operands.push_back(cv->getValue().getConstantValue()); } return operands; @@ -357,7 +362,16 @@ static std::optional> getOperandValuesImpl( std::optional> DeadCodeAnalysis::getOperandValues(Operation *op) { - return getOperandValuesImpl(op, [&](Value value) { + return getOperandValuesImpl(op, true, [&](Value value) { + auto *lattice = getOrCreate>(value); + lattice->useDefSubscribe(this); + return lattice; + }); +} + +SmallVector +DeadCodeAnalysis::getOperandValuesBestEffort(Operation *op) { + return *getOperandValuesImpl(op, false, [&](Value value) { auto *lattice = getOrCreate>(value); lattice->useDefSubscribe(this); return lattice; @@ -366,11 +380,9 @@ DeadCodeAnalysis::getOperandValues(Operation *op) { void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { // Try to deduce a single successor for the branch. - std::optional> operands = getOperandValues(branch); - if (!operands) - return; + SmallVector operands = getOperandValuesBestEffort(branch); - if (Block *successor = branch.getSuccessorForOperands(*operands)) { + if (Block *successor = branch.getSuccessorForOperands(operands)) { markEdgeLive(branch->getBlock(), successor); } else { // Otherwise, mark all successors as executable and outgoing edges. @@ -382,12 +394,10 @@ void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { void DeadCodeAnalysis::visitRegionBranchOperation( RegionBranchOpInterface branch) { // Try to deduce which regions are executable. 
- std::optional> operands = getOperandValues(branch); - if (!operands) - return; + SmallVector operands = getOperandValuesBestEffort(branch); SmallVector successors; - branch.getEntrySuccessorRegions(*operands, successors); + branch.getEntrySuccessorRegions(operands, successors); for (const RegionSuccessor &successor : successors) { // The successor can be either an entry block or the parent operation. ProgramPoint *point = diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index 29f57c602f9cb..cfa880e658f74 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -44,9 +44,9 @@ void AnalysisState::addDependency(ProgramPoint *dependent, (void)inserted; DATAFLOW_DEBUG({ if (inserted) { - llvm::dbgs() << "Creating dependency between " << debugName << " of " - << anchor << "\nand " << debugName << " on " << dependent - << "\n"; + llvm::dbgs() << "Creating dependency between \t" << debugName << " of " + << anchor << "\n and\t" << debugName + << " of " << *dependent << "\n"; } }); } @@ -125,7 +125,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { worklist.pop(); DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName - << "' on: " << point << "\n"); + << "' on: " << *point << "\n"); if (failed(analysis->visit(point))) return failure(); } diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp index e3245d68b3699..c1673b920ac23 100644 --- a/mlir/lib/Analysis/Liveness.cpp +++ b/mlir/lib/Analysis/Liveness.cpp @@ -369,7 +369,7 @@ Operation *LivenessBlockInfo::getStartOperation(Value value) const { Operation *definingOp = value.getDefiningOp(); // The given value is either live-in or is defined // in the scope of this block. 
- if (isLiveIn(value) || !definingOp) + if (!definingOp || definingOp->getBlock() != block) return &block->front(); return definingOp; } diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt index 5123c2a7cf916..8ad221387ef99 100644 --- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt +++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt @@ -19,7 +19,6 @@ add_mlir_conversion_library(MLIRTosaToLinalg MLIRLinalgDialect MLIRLinalgUtils MLIRMathDialect - MLIRPass MLIRSupport MLIRTensorDialect MLIRTosaDialect diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp index 84b76d33c3e67..5eaa9d5eccd11 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp @@ -65,6 +65,8 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos, [](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; }) .Case( [](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; }) + .Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; }) + .Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; }) .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; }) .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; }) .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; }) @@ -491,6 +493,14 @@ LogicalResult MemRefAccess::getAccessRelation(IntegerRelation &rel) const { IntegerRelation domainRel = domain; if (rel.getSpace().isUsingIds() && !domainRel.getSpace().isUsingIds()) domainRel.resetIds(); + + if (!rel.getSpace().isUsingIds()) { + assert(rel.getNumVars() == 0); + rel.resetIds(); + if (!domainRel.getSpace().isUsingIds()) + domainRel.resetIds(); + } + domainRel.appendVar(VarKind::Range, accessValueMap.getNumResults()); domainRel.mergeAndAlignSymbols(rel); domainRel.mergeLocalVars(rel); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp 
b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 8acb21d5074b4..984fed19dd788 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/ShapedOpInterfaces.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" @@ -294,10 +295,12 @@ bool mlir::affine::isValidDim(Value value) { return isValidDim(value, getAffineScope(defOp)); // This value has to be a block argument for an op that has the - // `AffineScope` trait or for an affine.for or affine.parallel. + // `AffineScope` trait or an induction var of an affine.for or + // affine.parallel. + if (isAffineInductionVar(value)) + return true; auto *parentOp = llvm::cast(value).getOwner()->getParentOp(); - return parentOp && (parentOp->hasTrait() || - isa(parentOp)); + return parentOp && parentOp->hasTrait(); } // Value can be used as a dimension id iff it meets one of the following @@ -316,10 +319,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) { auto *op = value.getDefiningOp(); if (!op) { - // This value has to be a block argument for an affine.for or an + // This value has to be an induction var for an affine.for or an // affine.parallel. - auto *parentOp = llvm::cast(value).getOwner()->getParentOp(); - return isa(parentOp); + return isAffineInductionVar(value); } // Affine apply operation is ok if all of its operands are ok. 
@@ -2482,6 +2484,10 @@ bool AffineForOp::matchingBoundOperandList() { SmallVector AffineForOp::getLoopRegions() { return {&getRegion()}; } +std::optional AffineForOp::getLoopResults() { + return {getResults()}; +} + std::optional> AffineForOp::getLoopInductionVars() { return SmallVector{getInductionVar()}; } @@ -3889,8 +3895,9 @@ static bool isResultTypeMatchAtomicRMWKind(Type resultType, case arith::AtomicRMWKind::muli: return isa(resultType); case arith::AtomicRMWKind::maximumf: - return isa(resultType); case arith::AtomicRMWKind::minimumf: + case arith::AtomicRMWKind::maxnumf: + case arith::AtomicRMWKind::minnumf: return isa(resultType); case arith::AtomicRMWKind::maxs: { auto intType = llvm::dyn_cast(resultType); diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index 4d30213cc6ec2..2ed1393cd0a41 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Operation.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/MapVector.h" #include "llvm/Support/CommandLine.h" @@ -86,8 +87,8 @@ struct AffineDataCopyGeneration /// Generates copies for memref's living in 'slowMemorySpace' into newly created /// buffers in 'fastMemorySpace', and replaces memory operations to the former -/// by the latter. -std::unique_ptr> +/// by the latter. Only load ops are handled for now.
+std::unique_ptr mlir::affine::createAffineDataCopyGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace, int minDmaTransferSize, uint64_t fastMemCapacityBytes) { @@ -95,7 +96,7 @@ mlir::affine::createAffineDataCopyGenerationPass( slowMemorySpace, fastMemorySpace, tagMemorySpace, minDmaTransferSize, fastMemCapacityBytes); } -std::unique_ptr> +std::unique_ptr mlir::affine::createAffineDataCopyGenerationPass() { return std::make_unique(); } @@ -203,9 +204,9 @@ void AffineDataCopyGeneration::runOnBlock(Block *block, } void AffineDataCopyGeneration::runOnOperation() { - func::FuncOp f = getOperation(); - OpBuilder topBuilder(f.getBody()); - zeroIndex = topBuilder.create(f.getLoc(), 0); + Operation* f = getOperation(); + OpBuilder topBuilder(f->getRegion(0)); + zeroIndex = topBuilder.create(f->getLoc(), 0); // Nests that are copy-in's or copy-out's; the root AffineForOps of those // nests are stored herein. @@ -214,7 +215,7 @@ void AffineDataCopyGeneration::runOnOperation() { // Clear recorded copy nests. copyNests.clear(); - for (auto &block : f) + for (auto &block : f->getRegion(0)) runOnBlock(&block, copyNests); // Promote any single iteration loops in the copy nests and collect diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp index e3f316443161f..8f59a2fe54a76 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -200,10 +200,10 @@ void LoopInvariantCodeMotion::runOnOperation() { // Walk through all loops in a function in innermost-loop-first order. This // way, we first LICM from the inner loop, and place the ops in // the outer loop, which in turn can be further LICM'ed. 
- getOperation().walk([&](AffineForOp op) { runOnAffineForOp(op); }); + getOperation()->walk([&](AffineForOp op) { runOnAffineForOp(op); }); } -std::unique_ptr> +std::unique_ptr mlir::affine::createAffineLoopInvariantCodeMotionPass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp index 5cc38f7051726..be295d9f25b22 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp @@ -38,7 +38,7 @@ struct AffineLoopNormalizePass } void runOnOperation() override { - getOperation().walk([&](Operation *op) { + getOperation()->walk([&](Operation *op) { if (auto affineParallel = dyn_cast(op)) normalizeAffineParallel(affineParallel); else if (auto affineFor = dyn_cast(op)) @@ -49,7 +49,7 @@ struct AffineLoopNormalizePass } // namespace -std::unique_ptr> +std::unique_ptr mlir::affine::createAffineLoopNormalizePass(bool promoteSingleIter) { return std::make_unique(promoteSingleIter); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp index fa0676b206826..de5b55db3fe36 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp @@ -60,12 +60,12 @@ struct ParallelizationCandidate { } // namespace void AffineParallelize::runOnOperation() { - func::FuncOp f = getOperation(); + Operation* f = getOperation(); // The walker proceeds in pre-order to process the outer loops first // and control the number of outer parallel loops. std::vector parallelizableLoops; - f.walk([&](AffineForOp loop) { + f->walk([&](AffineForOp loop) { SmallVector reductions; if (isLoopParallel(loop, parallelReductions ? 
&reductions : nullptr)) parallelizableLoops.emplace_back(loop, std::move(reductions)); @@ -92,3 +92,8 @@ void AffineParallelize::runOnOperation() { } } } + +std::unique_ptr +mlir::affine::createAffineParallelizePass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp index 16c0d3d8be3dc..fe0a668a59b68 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp @@ -16,8 +16,8 @@ #include "mlir/Analysis/AliasAnalysis.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dominance.h" -#include namespace mlir { namespace affine { @@ -40,13 +40,13 @@ struct AffineScalarReplacement } // namespace -std::unique_ptr> +std::unique_ptr mlir::affine::createAffineScalarReplacementPass() { return std::make_unique(); } void AffineScalarReplacement::runOnOperation() { - affineScalarReplace(getOperation(), getAnalysis(), - getAnalysis(), - getAnalysis()); + affineScalarReplace( + getOperation(), getAnalysis(), + getAnalysis(), getAnalysis()); } diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt index c42789b01bc9f..1c82822b2bd7f 100644 --- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms LoopUnroll.cpp LoopUnrollAndJam.cpp PipelineDataTransfer.cpp + RaiseMemrefDialect.cpp ReifyValueBounds.cpp SuperVectorize.cpp SimplifyAffineStructures.cpp diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp index 05c77070a70c1..4006577cdf05e 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp +++ 
b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp @@ -36,8 +36,7 @@ struct LoopCoalescingPass : public affine::impl::LoopCoalescingBase { void runOnOperation() override { - func::FuncOp func = getOperation(); - func.walk([](Operation *op) { + getOperation()->walk([](Operation *op) { if (auto scfForOp = dyn_cast(op)) (void)coalescePerfectlyNestedSCFForLoops(scfForOp); else if (auto affineForOp = dyn_cast(op)) @@ -48,7 +47,7 @@ struct LoopCoalescingPass } // namespace -std::unique_ptr> +std::unique_ptr mlir::affine::createLoopCoalescingPass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp index c8400dfe8cd5c..628816d1d1eeb 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp @@ -64,11 +64,11 @@ struct LoopTiling : public affine::impl::AffineLoopTilingBase { /// Creates a pass to perform loop tiling on all suitable loop nests of a /// Function. -std::unique_ptr> +std::unique_ptr mlir::affine::createLoopTilingPass(uint64_t cacheSizeBytes) { return std::make_unique(cacheSizeBytes); } -std::unique_ptr> +std::unique_ptr mlir::affine::createLoopTilingPass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp index 7ff77968c61ad..1c95331a27841 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -82,27 +83,25 @@ static bool isInnermostAffineForOp(AffineForOp op) { } /// Gathers loops that have no affine.for's nested within. 
-static void gatherInnermostLoops(FunctionOpInterface f, +static void gatherInnermostLoops(Operation* f, SmallVectorImpl &loops) { - f.walk([&](AffineForOp forOp) { + f->walk([&](AffineForOp forOp) { if (isInnermostAffineForOp(forOp)) loops.push_back(forOp); }); } void LoopUnroll::runOnOperation() { - FunctionOpInterface func = getOperation(); - if (func.isExternal()) - return; + Operation* func = getOperation(); if (unrollFull && unrollFullThreshold.hasValue()) { // Store short loops as we walk. SmallVector loops; // Gathers all loops with trip count <= minTripCount. Do a post order walk - // so that loops are gathered from innermost to outermost (or else - // unrolling an outer one may delete gathered inner ones). - getOperation().walk([&](AffineForOp forOp) { + // so that loops are gathered from innermost to outermost (or else unrolling + // an outer one may delete gathered inner ones). + getOperation()->walk([&](AffineForOp forOp) { std::optional tripCount = getConstantTripCount(forOp); if (tripCount && *tripCount <= unrollFullThreshold) loops.push_back(forOp); @@ -145,8 +144,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { cleanUpUnroll); } -std::unique_ptr> -mlir::affine::createLoopUnrollPass( +std::unique_ptr mlir::affine::createLoopUnrollPass( int unrollFactor, bool unrollUpToFactor, bool unrollFull, const std::function &getUnrollFactor) { return std::make_unique( diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp index 13640f085951e..442fbf66fefef 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp @@ -75,7 +75,7 @@ struct LoopUnrollAndJam }; } // namespace -std::unique_ptr> +std::unique_ptr mlir::affine::createLoopUnrollAndJamPass(int unrollJamFactor) { return std::make_unique( unrollJamFactor == -1 ? 
std::nullopt @@ -83,13 +83,11 @@ mlir::affine::createLoopUnrollAndJamPass(int unrollJamFactor) { } void LoopUnrollAndJam::runOnOperation() { - if (getOperation().isExternal()) - return; // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on // any for operation. - auto &entryBlock = getOperation().front(); + auto &entryBlock = getOperation()->getRegion(0).front(); if (auto forOp = dyn_cast(entryBlock.front())) (void)loopUnrollJamByFactor(forOp, unrollJamFactor); } diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp index 4be99aa197380..4199c4d10aaff 100644 --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -52,7 +52,7 @@ struct PipelineDataTransfer /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -std::unique_ptr> +std::unique_ptr mlir::affine::createPipelineDataTransferPass() { return std::make_unique(); } @@ -142,7 +142,7 @@ void PipelineDataTransfer::runOnOperation() { // gets deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); - getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); }); + getOperation()->walk([&](AffineForOp forOp) { forOps.push_back(forOp); }); for (auto forOp : forOps) runOnAffineForOp(forOp); } diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp new file mode 100644 index 0000000000000..21f5a9d6aa7d7 --- /dev/null +++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp @@ -0,0 +1,77 @@ +//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements functionality to convert memref load and store ops to +// the corresponding affine ops, inferring the affine map as needed. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Operation.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace affine { +#define GEN_PASS_DEF_RAISEMEMREFDIALECT +#include "mlir/Dialect/Affine/Passes.h.inc" +} // namespace affine +} // namespace mlir + +#define DEBUG_TYPE "raise-memref-to-affine" + +using namespace mlir; +using namespace mlir::affine; + +namespace { + +struct RaiseMemrefDialect + : public affine::impl::RaiseMemrefDialectBase { + + void runOnOperation() override { + auto *ctx = &getContext(); + Operation *op = getOperation(); + IRRewriter rewriter(ctx); + AffineMap map; + SmallVector mapArgs; + op->walk([&](Operation *op) { + rewriter.setInsertionPoint(op); + if (auto store = llvm::dyn_cast_or_null(op)) { + + if (succeeded(affine::convertValuesToAffineMapAndArgs( + ctx, store.getIndices(), map, mapArgs))) { + rewriter.replaceOpWithNewOp( + op, store.getValueToStore(), store.getMemRef(), map, mapArgs); + return; + } + + LLVM_DEBUG(llvm::dbgs() + << "[affine] Cannot raise memref op: " << op << "\n"); + + } else if (auto load = llvm::dyn_cast_or_null(op)) { + if (succeeded(affine::convertValuesToAffineMapAndArgs( + ctx, load.getIndices(), map, mapArgs))) { + rewriter.replaceOpWithNewOp(op, load.getMemRef(), map, + mapArgs); + return; + } + LLVM_DEBUG(llvm::dbgs() + << "[affine] Cannot raise memref op: " << op << 
"\n"); + } + }); + } +}; + +} // namespace + +std::unique_ptr mlir::affine::createRaiseMemrefToAffine() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp index 31711ade3153b..337cfc8308439 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Operation.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -81,24 +82,24 @@ struct SimplifyAffineStructures } // namespace -std::unique_ptr> +std::unique_ptr mlir::affine::createSimplifyAffineStructuresPass() { return std::make_unique(); } void SimplifyAffineStructures::runOnOperation() { - auto func = getOperation(); + Operation *func = getOperation(); simplifiedAttributes.clear(); - RewritePatternSet patterns(func.getContext()); - AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext()); - AffineForOp::getCanonicalizationPatterns(patterns, func.getContext()); - AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext()); + RewritePatternSet patterns(func->getContext()); + AffineApplyOp::getCanonicalizationPatterns(patterns, func->getContext()); + AffineForOp::getCanonicalizationPatterns(patterns, func->getContext()); + AffineIfOp::getCanonicalizationPatterns(patterns, func->getContext()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); // The simplification of affine attributes will likely simplify the op. Try to // fold/apply canonicalization patterns when we have affine dialect ops. 
SmallVector opsToSimplify; - func.walk([&](Operation *op) { + func->walk([&](Operation *op) { for (auto attr : op->getAttrs()) { if (auto mapAttr = dyn_cast(attr.getValue())) simplifyAndUpdateAttribute(op, attr.getName(), mapAttr); diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index eaaafaf68767e..f78651d27f735 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" @@ -1748,21 +1749,21 @@ static void vectorizeLoops(Operation *parentOp, DenseSet &loops, /// Applies vectorization to the current function by searching over a bunch of /// predetermined patterns. void Vectorize::runOnOperation() { - func::FuncOp f = getOperation(); + Operation* f = getOperation(); if (!fastestVaryingPattern.empty() && fastestVaryingPattern.size() != vectorSizes.size()) { - f.emitRemark("Fastest varying pattern specified with different size than " + f->emitRemark("Fastest varying pattern specified with different size than " "the vector size."); return signalPassFailure(); } if (vectorizeReductions && vectorSizes.size() != 1) { - f.emitError("Vectorizing reductions is supported only for 1-D vectors."); + f->emitError("Vectorizing reductions is supported only for 1-D vectors."); return signalPassFailure(); } if (llvm::any_of(vectorSizes, [](int64_t size) { return size <= 0; })) { - f.emitError("Vectorization factor must be greater than zero."); + f->emitError("Vectorization factor must be greater than zero."); return signalPassFailure(); } @@ -1772,7 +1773,7 @@ void Vectorize::runOnOperation() { // If 'vectorize-reduction=true' is provided, we also populate the // `reductionLoops` 
map. if (vectorizeReductions) { - f.walk([¶llelLoops, &reductionLoops](AffineForOp loop) { + f->walk([¶llelLoops, &reductionLoops](AffineForOp loop) { SmallVector reductions; if (isLoopParallel(loop, &reductions)) { parallelLoops.insert(loop); @@ -1782,7 +1783,7 @@ void Vectorize::runOnOperation() { } }); } else { - f.walk([¶llelLoops](AffineForOp loop) { + f->walk([¶llelLoops](AffineForOp loop) { if (isLoopParallel(loop)) parallelLoops.insert(loop); }); diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 5c94ec2985c3d..bcd7db3d21052 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -871,10 +871,10 @@ void mlir::affine::getPerfectlyNestedLoops( /// a temporary placeholder to test the mechanics of tiled code generation. /// Returns all maximal outermost perfect loop nests to tile. void mlir::affine::getTileableBands( - func::FuncOp f, std::vector> *bands) { + Operation* f, std::vector> *bands) { // Get maximal perfect nest of 'affine.for' insts starting from root // (inclusive). - for (AffineForOp forOp : f.getOps()) { + for (AffineForOp forOp : f->getRegion(0).getOps()) { SmallVector band; getPerfectlyNestedLoops(band, forOp); bands->push_back(band); @@ -2543,8 +2543,8 @@ gatherLoopsInBlock(Block *block, unsigned currLoopDepth, /// Gathers all AffineForOps in 'func.func' grouped by loop depth. void mlir::affine::gatherLoops( - func::FuncOp func, std::vector> &depthToLoops) { - for (auto &block : func) + Operation* func, std::vector> &depthToLoops) { + for (auto &block : func->getRegion(0)) gatherLoopsInBlock(&block, /*currLoopDepth=*/0, depthToLoops); // Remove last loop level from output since it's empty. 
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 2723cff6900d0..791a9587a6f36 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -21,14 +21,22 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" #include +#include #define DEBUG_TYPE "affine-utils" @@ -890,6 +898,8 @@ static void forwardStoreToLoad( // loads and stores. if (storeVal.getType() != loadOp.getValue().getType()) return; + LLVM_DEBUG(llvm::dbgs() << "Erased load (forwarded from store): " << loadOp + << "\n"); loadOp.getValue().replaceAllUsesWith(storeVal); // Record the memref for a later sweep to optimize away. memrefsToErase.insert(loadOp.getMemRef()); @@ -944,11 +954,124 @@ static void findUnusedStore(AffineWriteOpInterface writeA, mayAlias)) continue; + LLVM_DEBUG(llvm::dbgs() << "Erased store (unused): " << writeA << "\n"); opsToErase.push_back(writeA); break; } } +/// This attempts to find load-store pairs in the body of the loop +/// that could be replaced by an iter_args variable on the loop. The +/// initial load and the final store are moved out of the loop. For +/// such a pair to be eligible: +/// 1. the load must be followed by the store +/// 2. the memref must not be read again after the store +/// 3. the indices of the load and store must match AND be +/// loop-invariant for the given loop. 
+/// +/// This is a useful transformation as +/// - it exposes reduction dependencies that can be extracted by +/// --affine-parallelize +/// - it is a common pattern in code lowered from linalg. +/// - it exposes more opportunities for forwarding of load/store by +/// moving the load/store out of the loop and into a scope. +/// +static bool findReductionVariablesAndRewrite( + LoopLikeOpInterface loop, PostDominanceInfo &postDominanceInfo, + llvm::function_ref mayAlias) { + if (!loop.getLoopResults()) + return false; + + SmallVector> result; + auto *region = loop.getLoopRegions()[0]; + auto &block = region->front(); + + for (auto &op : block.without_terminator()) { + // iterate over ops to find loop-invariant load/store pairs + auto asLoad = dyn_cast(op); + if (!asLoad) { + continue; + } + + // Indices must be loop-invariant + bool isLoopInvariant = true; + for (auto operand : asLoad.getMapOperands()) { + if (!loop.isDefinedOutsideOfLoop(operand)) { + isLoopInvariant = false; + break; + } + } + if (!isLoopInvariant) + continue; + + // find a corresponding store + for (auto *user : asLoad.getMemRef().getUsers()) { + if (user->getBlock() != &block || user->isBeforeInBlock(&op)) + continue; + auto asStore = dyn_cast(user); + if (!asStore) + continue; + + // both load and store must access the same index + if (MemRefAccess(asLoad) != MemRefAccess(asStore)) { + break; + } + + // Check that nobody could be reading from the store before the next load, + // as we want to eliminate the store. 
+ if (!affine::hasNoInterveningEffect( + asStore.getOperation(), asLoad, mayAlias)) + break; + + // now let's just replace this pair of accesses with loop iter args + result.push_back({asLoad, asStore}); + } + } + if (result.empty()) + return false; + SmallVector newInitOperands; + SmallVector newYieldOperands; + IRRewriter rewriter(loop->getContext()); + rewriter.startOpModification(loop->getParentOp()); + rewriter.setInsertionPoint(loop); + for (auto [load, store] : result) { + auto rewrittenLoad = cast(rewriter.clone(*load)); + newInitOperands.push_back(rewrittenLoad.getValue()); + newYieldOperands.push_back(store.getValueToStore()); + } + + const auto numResults = loop.getLoopResults()->size(); + auto rewritten = loop.replaceWithAdditionalYields( + rewriter, newInitOperands, false, + [&](OpBuilder &b, Location loc, ArrayRef newBbArgs) { + return newYieldOperands; + }); + if (failed(rewritten)) { + rewriter.cancelOpModification(loop->getParentOp()); + return false; + } + auto newLoop = *rewritten; + + rewriter.setInsertionPointAfter(newLoop); + Operation *next = newLoop; + for (auto [loadStore, bbArg, loopRes] : + llvm::zip(result, rewritten->getRegionIterArgs().drop_front(numResults), + rewritten->getLoopResults()->drop_front(numResults))) { + auto load = loadStore.first; + rewriter.replaceOp(load, bbArg); + + auto store = loadStore.second; + rewriter.moveOpAfter(store, next); + store.getValueToStoreMutable().set(loopRes); + next = store; + } + + rewriter.finalizeOpModification(newLoop->getParentOp()); + LLVM_DEBUG(llvm::dbgs() << "Replaced loop reduction variable: \n" + << newLoop << "\n"); + return true; +} + // The load to load forwarding / redundant load elimination is similar to the // store to load forwarding. // loadA will be be replaced with loadB if: @@ -1036,21 +1159,55 @@ static void loadCSE(AffineReadOpInterface loadA, // currently only eliminates the stores only if no other loads/uses (other // than dealloc) remain. 
// -void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo, +void doForwarding(Operation *parentOp, DominanceInfo &domInfo, + PostDominanceInfo &postDomInfo, + llvm::function_ref mayAlias); + +void mlir::affine::affineScalarReplace(Operation *parentOp, + DominanceInfo &domInfo, PostDominanceInfo &postDomInfo, AliasAnalysis &aliasAnalysis) { + + auto mayAlias = [&](Value val1, Value val2) -> bool { + return !aliasAnalysis.alias(val1, val2).isNo(); + }; + + bool continueWalk; + do { + continueWalk = false; + + // Walk loops and rewrite reduction variables. Once a loop has been + // rewritten, we need to perform forwarding to eliminate the new store and + // loads introduced before and after the new loop. Then we need to continue + // doing that loop by loop. + parentOp->walk([&](AffineForOp loop) { + Operation *loopParent = loop->getParentOp(); + bool rewritten = + findReductionVariablesAndRewrite(loop, postDomInfo, mayAlias); + if (rewritten && loopParent != parentOp) { + doForwarding(loopParent, domInfo, postDomInfo, mayAlias); + continueWalk = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + } while (continueWalk); + + // cleanup the parent + doForwarding(parentOp, domInfo, postDomInfo, mayAlias); +} + +void doForwarding(Operation *parentOp, DominanceInfo &domInfo, + PostDominanceInfo &postDomInfo, + llvm::function_ref mayAlias) { // Load op's whose results were replaced by those forwarded from stores. SmallVector opsToErase; // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet memrefsToErase; - auto mayAlias = [&](Value val1, Value val2) -> bool { - return !aliasAnalysis.alias(val1, val2).isNo(); - }; - // Walk all load's and perform store to load forwarding. 
- f.walk([&](AffineReadOpInterface loadOp) { + parentOp->walk([&](AffineReadOpInterface loadOp) { forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo, mayAlias); }); for (auto *op : opsToErase) @@ -1058,7 +1215,7 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo, opsToErase.clear(); // Walk all store's and perform unused store elimination - f.walk([&](AffineWriteOpInterface storeOp) { + parentOp->walk([&](AffineWriteOpInterface storeOp) { findUnusedStore(storeOp, opsToErase, postDomInfo, mayAlias); }); for (auto *op : opsToErase) @@ -1091,7 +1248,7 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo, // To eliminate as many loads as possible, run load CSE after eliminating // stores. Otherwise, some stores are wrongly seen as having an intervening // effect. - f.walk([&](AffineReadOpInterface loadOp) { + parentOp->walk([&](AffineReadOpInterface loadOp) { loadCSE(loadOp, opsToErase, domInfo, mayAlias); }); for (auto *op : opsToErase) @@ -2050,3 +2207,136 @@ OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc, return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr, multiIndexAndStrides); } + +namespace { + +/// Find the index of the given value in the `dims` list, +/// and append it if it was not already in the list. The +/// dims list is a list of symbols or dimensions of the +/// affine map. Within the results of an affine map, they +/// are identified by their index, which is why we need +/// this function. +static std::optional +findInListOrAdd(Value value, llvm::SmallVectorImpl &dims, + function_ref isValidElement) { + + Value *loopIV = std::find(dims.begin(), dims.end(), value); + if (loopIV != dims.end()) { + // We found an IV that already has an index, return that index. + return {std::distance(dims.begin(), loopIV)}; + } + if (isValidElement(value)) { + // This is a valid element for the dim/symbol list, push this as a + // parameter. 
+    size_t idx = dims.size();
+    dims.push_back(value);
+    return idx;
+  }
+  return std::nullopt;
+}
+
+/// Convert a value to an affine expr if possible. Adds dims and symbols
+/// if needed.
+static AffineExpr toAffineExpr(Value value,
+                               llvm::SmallVectorImpl<Value> &affineDims,
+                               llvm::SmallVectorImpl<Value> &affineSymbols) {
+  using namespace matchers;
+  IntegerAttr::ValueType cst;
+  if (matchPattern(value, m_ConstantInt(&cst))) {
+    return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
+  }
+  Value lhs;
+  Value rhs;
+  if (matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs))) ||
+      matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
+    AffineExpr lhsE;
+    AffineExpr rhsE;
+    if ((lhsE = toAffineExpr(lhs, affineDims, affineSymbols)) &&
+        (rhsE = toAffineExpr(rhs, affineDims, affineSymbols))) {
+      AffineExprKind kind;
+      if (isa<arith::AddIOp>(value.getDefiningOp())) {
+        kind = mlir::AffineExprKind::Add;
+      } else {
+        kind = mlir::AffineExprKind::Mul;
+      }
+      return getAffineBinaryOpExpr(kind, lhsE, rhsE);
+    }
+  }
+
+  if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
+        return affine::isValidSymbol(v);
+      })) {
+    return getAffineSymbolExpr(*dimIx, value.getContext());
+  }
+
+  if (auto dimIx = findInListOrAdd(
+          value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
+
+    return getAffineDimExpr(*dimIx, value.getContext());
+  }
+
+  return {};
+}
+
+} // namespace
+
+LogicalResult mlir::affine::convertValuesToAffineMapAndArgs(
+    MLIRContext *ctx, ValueRange indices, AffineMap &map,
+    llvm::SmallVectorImpl<Value> &mapArgs) {
+  SmallVector<AffineExpr> results;
+  SmallVector<Value> symbols;
+  SmallVector<Value> dims;
+
+  for (Value indexExpr : indices) {
+    AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
+    if (!res) {
+      return failure();
+    }
+    results.push_back(res);
+  }
+
+  map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
+
+  dims.append(symbols);
+  mapArgs.swap(dims);
+  return success();
+}
+
+LogicalResult mlir::affine::convertValuesToAffineMapAndArgs(
+    MLIRContext *ctx,
    ArrayRef<OpFoldResult> indices, AffineMap &map,
+    llvm::SmallVectorImpl<OpFoldResult> &mapArgs) {
+  SmallVector<AffineExpr> results;
+  SmallVector<Value> symbols;
+  SmallVector<Value> dims;
+  SmallVector<OpFoldResult> constantSymbols;
+
+  for (OpFoldResult indexExpr : indices) {
+    if (auto asValue = llvm::dyn_cast_or_null<Value>(indexExpr)) {
+      AffineExpr res = toAffineExpr(asValue, dims, symbols);
+      if (!res) {
+        return failure();
+      }
+      results.push_back(res);
+    } else {
+      constantSymbols.push_back(indexExpr);
+      results.push_back(getAffineSymbolExpr(symbols.size(), ctx));
+      // Add a null symbol here to increment the next symbol id.
+      symbols.emplace_back();
+    }
+  }
+
+  map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
+
+  for (auto dim : dims) {
+    mapArgs.push_back(dim);
+  }
+  unsigned nextConstSymbol = 0;
+  for (auto symbol : symbols) {
+    if (symbol) {
+      mapArgs.push_back(symbol);
+    } else {
+      mapArgs.push_back(constantSymbols[nextConstSymbol++]);
+    }
+  }
+  return success();
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
index b74df4ff6060f..91c367ad8ca85 100644
--- a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
@@ -8,7 +8,10 @@
 
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
@@ -100,11 +103,23 @@ void mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
   resolvedIndices.clear();
   for (auto [offset, index, stride] :
        llvm::zip_equal(mixedSourceOffsets, indices, mixedSourceStrides)) {
-    AffineExpr off, idx, str;
-    bindSymbols(rewriter.getContext(), off, idx, str);
-    OpFoldResult ofr = makeComposedFoldedAffineApply(
-        rewriter, loc, AffineMap::get(0, 3, off + idx * str),
-        {offset,
index, stride});
+    AffineMap map;
+    SmallVector<OpFoldResult> mapArgs;
+    auto *ctx = rewriter.getContext();
+    if (failed(affine::convertValuesToAffineMapAndArgs(
+            ctx, {offset, index, stride}, map, mapArgs))) {
+      // Fall back to the generic symbolic form when the operands cannot be
+      // expressed as an affine map; never push a null Value into the result.
+      AffineExpr off, idx, str;
+      bindSymbols(ctx, off, idx, str);
+      OpFoldResult fallback = makeComposedFoldedAffineApply(
+          rewriter, loc, AffineMap::get(0, 3, off + idx * str),
+          {offset, index, stride});
+      resolvedIndices.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, fallback));
+      continue;
+    }
+    AffineExpr off, ix, str;
+    bindDims(ctx, off, ix, str);
+    auto nextMap = AffineMap::get(3, 0, off + ix * str);
+    auto composedMap = nextMap.compose(map);
+
+    OpFoldResult ofr =
+        makeComposedFoldedAffineApply(rewriter, loc, composedMap, mapArgs);
+
     resolvedIndices.push_back(
         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
index 5178d4a62f374..0533358d11abc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
@@ -95,14 +95,11 @@ struct OptimizeAllocationLiveness
   OptimizeAllocationLiveness() = default;
 
   void runOnOperation() override {
-    func::FuncOp func = getOperation();
-
-    if (func.isExternal())
-      return;
+    Operation *func = getOperation();
 
     BufferViewFlowAnalysis analysis = BufferViewFlowAnalysis(func);
 
-    func.walk([&](MemoryEffectOpInterface memEffectOp) -> WalkResult {
+    func->walk([&](MemoryEffectOpInterface memEffectOp) -> WalkResult {
       if (!hasMemoryAllocEffect(memEffectOp))
         return WalkResult::advance();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..cd4f9d20730c6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1354,11 +1354,12 @@ LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
 static ParseResult parseDstStyleOp(
     OpAsmParser &parser, OperationState &result,
     function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
-        nullptr) {
+        nullptr,
+    bool addOperandSegmentSizes = false) {
   // Parse `ins` and `outs`.
SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes, - /*addOperandSegmentSizes=*/false)) + addOperandSegmentSizes)) return failure(); // Add result types. @@ -1707,9 +1708,12 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { } if (parseDstStyleOp( - parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { + parser, result, + [&](OpAsmParser &parser, NamedAttrList &attributes) { return parseDenseI64ArrayAttr(parser, attributes, "dimensions"); - })) + }, + /*addOperandSegmentSizes=*/true)) + return failure(); if (payloadOpName.has_value()) { @@ -1744,7 +1748,9 @@ void ReduceOp::print(OpAsmPrinter &p) { printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); - p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); + p.printOptionalAttrDict( + (*this)->getAttrs(), + {getDimensionsAttrName(), getOperandSegmentSizesAttrName()}); if (!payloadOp) { // Print region if the payload op was not detected. 
p.increaseIndent(); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index f630c48cdcaa1..d54590acf09e7 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -309,7 +310,7 @@ struct MemRefDestructurableTypeExternalModel constexpr int64_t maxMemrefSizeForDestructuring = 16; if (!memrefType.hasStaticShape() || memrefType.getNumElements() > maxMemrefSizeForDestructuring || - memrefType.getNumElements() == 1) + memrefType.getShape().empty()) return {}; DenseMap destructured; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 11597505e7888..74293241fbf0c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -11,18 +11,27 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" +#include +#include using namespace mlir; using 
namespace mlir::memref; @@ -850,11 +859,40 @@ struct FoldEmptyCopy final : public OpRewritePattern { return failure(); } }; + +struct DestructureSingleEltCopy final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CopyOp copyOp, + PatternRewriter &rewriter) const override { + if (copyOp.getSource().getType() == copyOp.getTarget().getType()) { + auto ty = copyOp.getSource().getType(); + if (ty.hasRank() && ty.getNumElements() == 1 && ty.hasStaticShape()) { + // copy of one element + rewriter.setInsertionPoint(copyOp); + SmallVector indices; + if (!ty.getShape().empty()) { + Value cst0 = rewriter.create( + copyOp->getLoc(), rewriter.getIndexAttr(0)); + indices.append(ty.getShape().size(), cst0); + } + auto loaded = rewriter.create( + copyOp->getLoc(), copyOp.getSource(), indices); + rewriter.create(copyOp->getLoc(), loaded.getResult(), + copyOp.getTarget(), indices); + rewriter.eraseOp(copyOp); + return success(); + } + } + return failure(); + } +}; } // namespace void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } LogicalResult CopyOp::fold(FoldAdaptor adaptor, @@ -1676,6 +1714,35 @@ LogicalResult LoadOp::verify() { } OpFoldResult LoadOp::fold(FoldAdaptor adaptor) { + + if (auto getCst = mlir::dyn_cast_or_null(getMemref().getDefiningOp())) { + auto global = mlir::dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(getCst, getCst.getNameAttr())); + if (global && global.getConstant() && global.getInitialValue()) { + auto constIndices = adaptor.getIndices(); + if (llvm::all_of(constIndices, [](auto attr) { + return mlir::isa(attr); + })) { + SmallVector index; + for (auto attr : constIndices) { + index.push_back(cast(attr).getUInt()); + } + // all indices are constant, value is constant + if (auto constValue = + mlir::dyn_cast(*global.getInitialValue())) { + if (constValue.isValidIndex(index)) { + auto flatIdx = 
constValue.getFlattenedIndex(index);
+          auto values = constValue.getValues<Attribute>();
+          auto iter = values.begin();
+          if (std::next(iter, flatIdx) < values.end()) {
+            // Fold to the element at the flattened index (the previous
+            // `*iter` always returned element 0).
+            return OpFoldResult(*std::next(iter, flatIdx));
+          }
+        }
+      }
+    }
+  }
+
   /// load(memrefcast) -> load
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..2dc07dba62d15 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -661,8 +661,17 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
   case arith::AtomicRMWKind::ori:
     return builder.create<vector::ReductionOp>(vector.getLoc(),
                                                CombiningKind::OR, vector);
-  // TODO: Add remaining reduction operations.
+  case arith::AtomicRMWKind::maxnumf:
+    return builder.create<vector::ReductionOp>(vector.getLoc(),
+                                               CombiningKind::MAXNUMF, vector);
+  case arith::AtomicRMWKind::minnumf:
+    return builder.create<vector::ReductionOp>(vector.getLoc(),
+                                               CombiningKind::MINNUMF, vector);
+  case arith::AtomicRMWKind::assign:
+    (void)emitOptionalError(loc,
+                            "Reduction operation type not supported (assign)");
+    break;
   default:
+    // TODO: decide whether unsupported kinds should assert instead.
(void)emitOptionalError(loc, "Reduction operation type not supported"); break; } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 67c18189b85e0..ee4710120550b 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -15,19 +15,27 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/Threading.h" #include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Support/FileUtilities.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/CrashRecoveryContext.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/Mutex.h" #include "llvm/Support/Signals.h" #include "llvm/Support/Threading.h" #include "llvm/Support/ToolOutputFile.h" +#include #include +#include using namespace mlir; using namespace mlir::detail; @@ -108,16 +116,19 @@ namespace detail { struct OpPassManagerImpl { OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting) : name(opName.getStringRef().str()), opName(opName), - initializationGeneration(0), nesting(nesting) {} + initializationGeneration(0), nesting(nesting), + isRecursiveAnchor(false) {} OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting) : name(name == OpPassManager::getAnyOpAnchorName() ? 
"" : name.str()), - initializationGeneration(0), nesting(nesting) {} + initializationGeneration(0), nesting(nesting), + isRecursiveAnchor(false) {} OpPassManagerImpl(OpPassManager::Nesting nesting) - : initializationGeneration(0), nesting(nesting) {} + : initializationGeneration(0), nesting(nesting), + isRecursiveAnchor(false) {} OpPassManagerImpl(const OpPassManagerImpl &rhs) : name(rhs.name), opName(rhs.opName), initializationGeneration(rhs.initializationGeneration), - nesting(rhs.nesting) { + nesting(rhs.nesting), isRecursiveAnchor(false) { for (const std::unique_ptr &pass : rhs.passes) { std::unique_ptr newPass = pass->clone(); newPass->threadingSibling = pass.get(); @@ -193,12 +204,17 @@ struct OpPassManagerImpl { /// Control the implicit nesting of passes that mismatch the name set for this /// OpPassManager. OpPassManager::Nesting nesting; + + /// Whether the anchor is recursive + bool isRecursiveAnchor; }; } // namespace detail } // namespace mlir void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) { assert(name == rhs.name && "merging unrelated pass managers"); + assert(isRecursiveAnchor == rhs.isRecursiveAnchor && + "anchor fetching method is different"); for (auto &pass : passes) rhs.passes.push_back(std::move(pass)); passes.clear(); @@ -206,7 +222,7 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) { OpPassManager &OpPassManagerImpl::nest(OpPassManager &&nested) { auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); - addPass(std::unique_ptr(adaptor)); + passes.emplace_back(std::unique_ptr(adaptor)); return adaptor->getPassManagers().front(); } @@ -215,9 +231,13 @@ void OpPassManagerImpl::addPass(std::unique_ptr pass) { // implicitly nest a pass manager for this operation if enabled. 
std::optional pmOpName = getOpName(); std::optional passOpName = pass->getOpName(); - if (pmOpName && passOpName && *pmOpName != *passOpName) { - if (nesting == OpPassManager::Nesting::Implicit) - return nest(*passOpName).addPass(std::move(pass)); + if (pmOpName && ((passOpName && *passOpName != *pmOpName) || + pass->shouldImplicitlyNestOn(*pmOpName))) { + if (nesting == OpPassManager::Nesting::Implicit) { + if (passOpName) + return nest(*passOpName).addPass(std::move(pass)); + return nestAny().addPass(std::move(pass)); + } llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() + "' restricted to '" + *passOpName + "' on a PassManager intended to run on '" + @@ -390,8 +410,11 @@ StringRef OpPassManager::getOpAnchorName() const { /// Prints out the passes of the pass manager as the textual representation /// of pipelines. void printAsTextualPipeline( - raw_ostream &os, StringRef anchorName, + raw_ostream &os, StringRef anchorName, bool hasRecursiveAnchor, const llvm::iterator_range &passes) { + if (hasRecursiveAnchor) { + os << "**"; + } os << anchorName << "("; llvm::interleave( passes, [&](mlir::Pass &pass) { pass.printAsTextualPipeline(os); }, @@ -401,7 +424,7 @@ void printAsTextualPipeline( void OpPassManager::printAsTextualPipeline(raw_ostream &os) const { StringRef anchorName = getOpAnchorName(); ::printAsTextualPipeline( - os, anchorName, + os, anchorName, hasRecursiveAnchor(), {MutableArrayRef>{impl->passes}.begin(), MutableArrayRef>{impl->passes}.end()}); } @@ -426,6 +449,14 @@ void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; } OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; } +void OpPassManager::setRecursiveAnchorFetching(bool enabled) { + impl->isRecursiveAnchor = enabled; +} + +bool OpPassManager::hasRecursiveAnchor() const { + return impl->isRecursiveAnchor; +} + LogicalResult OpPassManager::initialize(MLIRContext *context, unsigned newInitGeneration) { if 
(impl->initializationGeneration == newInitGeneration) @@ -464,7 +495,6 @@ llvm::hash_code OpPassManager::hash() { return hashCode; } - //===----------------------------------------------------------------------===// // OpToOpPassAdaptor //===----------------------------------------------------------------------===// @@ -638,6 +668,10 @@ LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx, auto hasScheduleConflictWith = [&](OpPassManager &genericPM, MutableArrayRef otherPMs) { return llvm::any_of(otherPMs, [&](OpPassManager &pm) { + /// Anchor fetching methods must match + if (pm.hasRecursiveAnchor() != genericPM.hasRecursiveAnchor()) + return true; + // If this is a non-generic pass manager, a conflict will arise if a // non-generic pass manager's operation name can be scheduled on the // generic passmanager. @@ -669,11 +703,13 @@ LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx, // into it. if (auto *existingPM = findPassManagerWithAnchor(rhs.mgrs, pm.getOpAnchorName())) { - pm.getImpl().mergeInto(existingPM->getImpl()); - } else { - // Otherwise, add the given pass manager to the list. - rhs.mgrs.emplace_back(std::move(pm)); + if (existingPM->hasRecursiveAnchor() == pm.hasRecursiveAnchor()) { + pm.getImpl().mergeInto(existingPM->getImpl()); + continue; + } } + // Otherwise, add the given pass manager to the list. 
+ rhs.mgrs.emplace_back(std::move(pm)); } mgrs.clear(); @@ -702,6 +738,11 @@ std::string OpToOpPassAdaptor::getAdaptorName() { return name; } +bool OpToOpPassAdaptor::hasRecursiveAnchor() { + return llvm::all_of( + mgrs, [](OpPassManager &pm) { return pm.hasRecursiveAnchor(); }); +} + void OpToOpPassAdaptor::runOnOperation() { llvm_unreachable( "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor"); @@ -721,18 +762,36 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) { PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(), this}; auto *instrumentor = am.getPassInstrumentor(); - for (auto ®ion : getOperation()->getRegions()) { - for (auto &block : region) { - for (auto &op : block) { - auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext()); - if (!mgr) - continue; - - // Run the held pipeline over the current operation. - unsigned initGeneration = mgr->impl->initializationGeneration; - if (failed(runPipeline(*mgr, &op, am.nest(&op), verifyPasses, - initGeneration, instrumentor, &parentInfo))) - signalPassFailure(); + auto handleOp = [&](Operation &op) -> WalkResult { + auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext()); + if (!mgr) + return WalkResult::advance(); + + // Run the held pipeline over the current operation. + unsigned initGeneration = mgr->impl->initializationGeneration; + if (failed(runPipeline(*mgr, &op, am.nest(&op), verifyPasses, + initGeneration, instrumentor, &parentInfo))) { + signalPassFailure(); + return WalkResult::interrupt(); + } + // if we could run the pipeline, we skip exploration of its subtree. 
+ return WalkResult::skip(); + }; + + if (hasRecursiveAnchor()) { + for (auto ®ion : getOperation()->getRegions()) { + auto res = region.walk( + [&](Operation *op) -> WalkResult { return handleOp(*op); }); + if (res.wasInterrupted()) + return; + } + } else { + for (auto ®ion : getOperation()->getRegions()) { + for (auto &block : region) { + for (auto &op : block) { + if (handleOp(op).wasInterrupted()) + return; + } } } } @@ -776,18 +835,39 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) { // operation, as well as providing a queue of operations to execute over. std::vector opInfos; DenseMap> knownOpPMIdx; + + auto handleOp = [&](Operation &op) -> LogicalResult { + // Get the pass manager index for this operation type. + auto pmIdxIt = knownOpPMIdx.try_emplace(op.getName(), std::nullopt); + if (pmIdxIt.second) { + if (auto *mgr = findPassManagerFor(mgrs, op.getName(), *context)) + pmIdxIt.first->second = std::distance(mgrs.begin(), mgr); + } + + // If this operation can be scheduled, add it to the list. + if (pmIdxIt.first->second) { + opInfos.emplace_back(*pmIdxIt.first->second, &op, am.nest(&op)); + return success(); + } + return failure(); + }; + for (auto ®ion : getOperation()->getRegions()) { - for (Operation &op : region.getOps()) { - // Get the pass manager index for this operation type. - auto pmIdxIt = knownOpPMIdx.try_emplace(op.getName(), std::nullopt); - if (pmIdxIt.second) { - if (auto *mgr = findPassManagerFor(mgrs, op.getName(), *context)) - pmIdxIt.first->second = std::distance(mgrs.begin(), mgr); - } - // If this operation can be scheduled, add it to the list. - if (pmIdxIt.first->second) - opInfos.emplace_back(*pmIdxIt.first->second, &op, am.nest(&op)); + if (hasRecursiveAnchor()) { + // in that case the next nested ops to process are fetched recursively + region.walk([&](Operation *op) { + if (succeeded(handleOp(*op))) { + // if we can run the pipeline, we skip exploration of its subtree. 
+ return WalkResult::skip(); + } + return WalkResult::advance(); + }); + } else { + // here they are only fetched from the children + for (Operation &op : region.getOps()) { + (void)handleOp(op); + } } } @@ -847,7 +927,7 @@ void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; } LogicalResult PassManager::run(Operation *op) { MLIRContext *context = getContext(); std::optional anchorOp = getOpName(*context); - if (anchorOp && anchorOp != op->getName()) + if (anchorOp && anchorOp != op->getName() && !hasRecursiveAnchor()) return emitError(op->getLoc()) << "can't run '" << getOpAnchorName() << "' pass manager on '" << op->getName() << "' op"; @@ -869,7 +949,8 @@ LogicalResult PassManager::run(Operation *op) { // Initialize all of the passes within the pass manager with a new generation. llvm::hash_code newInitKey = context->getRegistryHash(); llvm::hash_code pipelineKey = hash(); - if (newInitKey != initializationKey || pipelineKey != pipelineInitializationKey) { + if (newInitKey != initializationKey || + pipelineKey != pipelineInitializationKey) { if (failed(initialize(context, impl->initializationGeneration + 1))) return failure(); initializationKey = newInitKey; diff --git a/mlir/lib/Pass/PassCrashRecovery.cpp b/mlir/lib/Pass/PassCrashRecovery.cpp index 8c6d865cb31dd..74d9dfa679728 100644 --- a/mlir/lib/Pass/PassCrashRecovery.cpp +++ b/mlir/lib/Pass/PassCrashRecovery.cpp @@ -442,11 +442,11 @@ makeReproducerStreamFactory(StringRef outputFile) { } void printAsTextualPipeline( - raw_ostream &os, StringRef anchorName, + raw_ostream &os, StringRef anchorName, bool hasRecursiveAnchor, const llvm::iterator_range &passes); std::string mlir::makeReproducer( - StringRef anchorName, + StringRef anchorName, bool hasRecursiveAnchor, const llvm::iterator_range &passes, Operation *op, StringRef outputFile, bool disableThreads, bool verifyPasses) { @@ -454,7 +454,7 @@ std::string mlir::makeReproducer( std::string description; std::string pipelineStr; 
llvm::raw_string_ostream passOS(pipelineStr); - ::printAsTextualPipeline(passOS, anchorName, passes); + ::printAsTextualPipeline(passOS, anchorName, hasRecursiveAnchor, passes); appendReproducer(description, op, makeReproducerStreamFactory(outputFile), pipelineStr, disableThreads, verifyPasses); return description; diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index 5cc726295c9f1..02534a777641b 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -58,6 +58,8 @@ class OpToOpPassAdaptor /// Returns the adaptor pass name. std::string getAdaptorName(); + bool hasRecursiveAnchor(); + private: /// Run this pass adaptor synchronously. void runOnOperationImpl(bool verifyPasses); diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index ece2fdaed0dfd..ada1da8050760 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -546,9 +546,17 @@ class TextualPipeline { /// the name is the name of a pass, the InnerPipeline is empty, since passes /// cannot contain inner pipelines. 
struct PipelineElement { - PipelineElement(StringRef name) : name(name) {} + PipelineElement(StringRef name) { + if (name.starts_with("**")) { + this->name = name.drop_front(2); + this->hasRecursiveAnchor = true; + } else { + this->name = name; + } + } StringRef name; + bool hasRecursiveAnchor = false; StringRef options; const PassRegistryEntry *registryEntry = nullptr; std::vector innerPipeline; @@ -755,10 +763,13 @@ LogicalResult TextualPipeline::addToPipeline( return errorHandler("failed to add `" + elt.name + "` with options `" + elt.options + "`"); } - } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name), - errorHandler))) { - return errorHandler("failed to add `" + elt.name + "` with options `" + - elt.options + "` to inner pipeline"); + } else { + auto &nested = pm.nest(elt.name); + if (failed(addToPipeline(elt.innerPipeline, nested, errorHandler))) { + return errorHandler("failed to add `" + elt.name + "` with options `" + + elt.options + "` to inner pipeline"); + } + nested.setRecursiveAnchorFetching(elt.hasRecursiveAnchor); } } return success(); diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 9bbf91de18305..7241e2386471d 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -27,6 +27,8 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Visitors.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -37,6 +39,7 @@ #include "mlir/Tools/ParseUtilities.h" #include "mlir/Tools/Plugins/DialectPlugin.h" #include "mlir/Tools/Plugins/PassPlugin.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileUtilities.h" @@ -192,6 +195,12 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { static cl::list passPlugins( 
"load-pass-plugin", cl::desc("Load passes from plugin library")); + static cl::opt passPipelineAnchor{ + "pass-pipeline-anchor", llvm::cl::ValueOptional, + cl::desc("Specify an operation name that will be used as the anchor of " + "the CLI pass pipeline"), + cl::location(passPipelineAnchorFlag), cl::init("")}; + + static cl::opt generateReproducerFile( + "mlir-generate-reproducer", @@ -285,7 +294,14 @@ MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser( emitError(UnknownLoc::get(pm.getContext())) << msg; return failure(); }; - if (failed(passPipeline.addToPipeline(pm, errorHandler))) + + OpPassManager *oppm = &pm; + if (auto anchor = getPassPipelineAnchor()) { + oppm = &pm.nest(*anchor); + oppm->setRecursiveAnchorFetching(true); + } + + if (failed(passPipeline.addToPipeline(*oppm, errorHandler))) return failure(); if (this->shouldDumpPassPipeline()) { @@ -414,11 +430,11 @@ static LogicalResult doVerifyRoundTrip(Operation *op, return success(succeeded(txtStatus) && succeeded(bcStatus)); } -/// Perform the actions on the input file indicated by the command line flags -/// within the specified context. +/// Perform the actions on the input file indicated by the command line +/// flags within the specified context. /// -/// This typically parses the main source file, runs zero or more optimization -/// passes, then prints the output. +/// This typically parses the main source file, runs zero or more +/// optimization passes, then prints the output. /// static LogicalResult performActions(raw_ostream &os,
- PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit); + StringRef rootName = op.get()->getName().getStringRef(); + PassManager pm(context, rootName, PassManager::Nesting::Implicit); + pm.enableVerifier(config.shouldVerifyPasses()); if (failed(applyPassManagerCLOptions(pm))) return failure(); @@ -470,7 +488,7 @@ performActions(raw_ostream &os, if (failed(config.setupPassPipeline(pm))) return failure(); - // Run the pipeline. + // Run the pipeline on the root if (failed(pm.run(*op))) return failure(); @@ -478,7 +496,7 @@ performActions(raw_ostream &os, if (!config.getReproducerFilename().empty()) { StringRef anchorName = pm.getAnyOpAnchorName(); const auto &passes = pm.getPasses(); - makeReproducer(anchorName, passes, op.get(), + makeReproducer(anchorName, pm.hasRecursiveAnchor(), passes, op.get(), config.getReproducerFilename()); } diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index db8be38a51443..0e9a9dc43fe61 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -10,8 +10,10 @@ #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" +#include namespace mlir { #define GEN_PASS_DEF_SROA @@ -248,6 +250,10 @@ namespace { struct SROA : public impl::SROABase { using impl::SROABase::SROABase; + bool shouldImplicitlyNestOn(llvm::StringRef name) const final { + return name == ModuleOp::getOperationName(); + } + void runOnOperation() override { Operation *scopeOp = getOperation(); diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir new file mode 100644 index 0000000000000..d8f2aaab4839e --- /dev/null +++ b/mlir/test/Dialect/Affine/raise-memref.mlir @@ -0,0 +1,118 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | 
FileCheck %s + +// CHECK-LABEL: func @reduce_window_max( +func.func @reduce_window_max() { + %cst = arith.constant 0.000000e+00 : f32 + %0 = memref.alloc() : memref<1x8x8x64xf32> + %1 = memref.alloc() : memref<1x18x18x64xf32> + affine.for %arg0 = 0 to 1 { + affine.for %arg1 = 0 to 8 { + affine.for %arg2 = 0 to 8 { + affine.for %arg3 = 0 to 64 { + memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> + } + } + } + } + affine.for %arg0 = 0 to 1 { + affine.for %arg1 = 0 to 8 { + affine.for %arg2 = 0 to 8 { + affine.for %arg3 = 0 to 64 { + affine.for %arg4 = 0 to 1 { + affine.for %arg5 = 0 to 3 { + affine.for %arg6 = 0 to 3 { + affine.for %arg7 = 0 to 1 { + %2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> + %21 = arith.addi %arg0, %arg4 : index + %22 = arith.constant 2 : index + %23 = arith.muli %arg1, %22 : index + %24 = arith.addi %23, %arg5 : index + %25 = arith.muli %arg2, %22 : index + %26 = arith.addi %25, %arg6 : index + %27 = arith.addi %arg3, %arg7 : index + %3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32> + %4 = arith.cmpf ogt, %2, %3 : f32 + %5 = arith.select %4, %2, %3 : f32 + memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> + } + } + } + } + } + } + } + } + return +} + +// CHECK: %[[cst:.*]] = arith.constant 0 +// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32> +// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32> +// CHECK: affine.for %[[arg0:.*]] = +// CHECK: affine.for %[[arg1:.*]] = +// CHECK: affine.for %[[arg2:.*]] = +// CHECK: affine.for %[[arg3:.*]] = +// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : +// CHECK: affine.for %[[a0:.*]] = +// CHECK: affine.for %[[a1:.*]] = +// CHECK: affine.for %[[a2:.*]] = +// CHECK: affine.for %[[a3:.*]] = +// CHECK: affine.for %[[a4:.*]] = +// CHECK: affine.for %[[a5:.*]] = +// CHECK: affine.for %[[a6:.*]] = +// CHECK: affine.for %[[a7:.*]] = +// CHECK: %[[lhs:.*]] = affine.load 
%[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : +// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : +// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32 +// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32 +// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : + +// CHECK-LABEL: func @symbols( +func.func @symbols(%N : index) { + %0 = memref.alloc() : memref<1024x1024xf32> + %1 = memref.alloc() : memref<1024x1024xf32> + %2 = memref.alloc() : memref<1024x1024xf32> + %cst1 = arith.constant 1 : index + %cst2 = arith.constant 2 : index + affine.for %i = 0 to %N { + affine.for %j = 0 to %N { + %7 = memref.load %2[%i, %j] : memref<1024x1024xf32> + %10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index { + %12 = arith.muli %N, %cst2 : index + %13 = arith.addi %12, %cst1 : index + %14 = arith.addi %13, %j : index + %5 = memref.load %0[%i, %12] : memref<1024x1024xf32> + %6 = memref.load %1[%14, %j] : memref<1024x1024xf32> + %8 = arith.mulf %5, %6 : f32 + %9 = arith.addf %7, %8 : f32 + %4 = arith.addi %N, %cst1 : index + %11 = arith.addi %ax, %cst1 : index + memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol + memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be raised + %something = "ab.v"() : () -> index + memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered + affine.yield %11 : index + } + } + } + return +} + +// CHECK: %[[cst1:.*]] = arith.constant 1 : index +// CHECK: %[[v0:.*]] = memref.alloc() : memref< +// CHECK: %[[v1:.*]] = memref.alloc() : memref< +// CHECK: %[[v2:.*]] = memref.alloc() : memref< +// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 { +// CHECK: affine.for %[[a2:.*]] = 0 to %arg0 { +// CHECK: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32> +// CHECK: affine.for %[[a3:.*]] = 0 
to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) { +// CHECK: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : +// CHECK: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : +// CHECK: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]] +// CHECK: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]] +// CHECK: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]] +// CHECK: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : +// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : +// CHECK: %[[lhs7:.*]] = "ab.v" +// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : +// CHECK: affine.yield %[[lhs6]] diff --git a/mlir/test/Dialect/Affine/scalrep.mlir b/mlir/test/Dialect/Affine/scalrep.mlir index 092597860c8d9..6b9d4fc1ac15f 100644 --- a/mlir/test/Dialect/Affine/scalrep.mlir +++ b/mlir/test/Dialect/Affine/scalrep.mlir @@ -141,9 +141,14 @@ func.func @store_load_store_nested_no_fwd(%N : index) { affine.for %i0 = 0 to 10 { affine.store %cf7, %m[%i0] : memref<10xf32> affine.for %i1 = 0 to %N { - // CHECK: %{{[0-9]+}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK: %[[C7:.*]] = arith.constant 7.0{{.*}} + // CHECK: %[[C9:.*]] = arith.constant 9.0{{.*}} + // CHECK: %{{[0-9]+}} = affine.for %{{.*}} = 0 to %{{.*}} iter_args(%[[A:.*]] = %[[C7]]) -> (f32) + // CHECK-NEXT: %[[R:.*]] = arith.addf %[[A]], %[[A]] : f32 + // CHECK: affine.yield %[[C9]] : f32 %v0 = affine.load %m[%i0] : memref<10xf32> %v1 = arith.addf %v0, %v0 : f32 + "use"(%v1) : (f32) -> () affine.store %cf9, %m[%i0] : memref<10xf32> } } @@ -423,7 +428,8 @@ func.func @load_load_store_2_loops_no_cse(%N : index, %m : memref<10xf32>) { // CHECK: affine.load %v0 = affine.load %m[%i0] : memref<10xf32> affine.for %i1 = 0 to %N { - // CHECK: affine.load + // CHECK: iter_args + // CHECK-NOT: affine.load %v1 = affine.load %m[%i0] : memref<10xf32> %v2 = arith.addf %v0, %v1 : f32 affine.store %v2, %m[%i0] : memref<10xf32> @@ -556,10 +562,11 @@ func.func 
@reduction_multi_store() -> memref<1xf32> { "test.foo"(%m) : (f32) -> () } -// CHECK: affine.for -// CHECK: affine.load -// CHECK: affine.store %[[S:.*]], -// CHECK-NEXT: "test.foo"(%[[S]]) +// CHECK: affine.for {{.*}} +// CHECK-NEXT: %[[A:.*]] = affine.load +// CHECK-NEXT: %[[X:.*]] = arith.addf %[[A]], +// CHECK-NEXT: affine.store %[[X]] +// CHECK-NEXT: "test.foo"(%[[X]]) return %A : memref<1xf32> } @@ -890,6 +897,34 @@ func.func @parallel_surrounding_for() { // CHECK-NEXT: return } +// CHECK-LABEL: func @reduction_extraction +func.func @reduction_extraction(%x : memref<10x10xf32>) -> f32 { + %b = memref.alloc() : memref + %cst = arith.constant 0.0 : f32 + affine.store %cst, %b[] : memref + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + %v0 = affine.load %x[%i0,%i1] : memref<10x10xf32> + %acc = affine.load %b[] : memref + %v1 = arith.addf %acc, %v0 : f32 + affine.store %v1, %b[] : memref + } + } + %x2 = affine.load %b[]: memref + return %x2 : f32 +// CHECK: %[[I:.*]] = arith.constant 0{{.*}} : f32 +// CHECK-NEXT: %[[SUM2:.*]] = affine.for %{{.*}} = 0 to 10 iter_args(%[[ACC2:.*]] = %[[I]]) -> (f32) { +// CHECK-NEXT: %[[SUM:.*]] = affine.for %{{.*}} = 0 to 10 iter_args(%[[ACC:.*]] = %[[ACC2]]) -> (f32) { +// CHECK-NEXT: %[[X:.*]] = affine.load {{.*}} : memref<10x10xf32> +// CHECK-NEXT: %[[Y:.*]] = arith.addf %[[ACC]], %[[X]] : f32 +// CHECK-NEXT: affine.yield %[[Y]] : f32 +// CHECK-NEXT: } +// CHECK-NEXT: affine.yield %[[SUM]] : f32 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[SUM2]] : f32 +} + + // CHECK-LABEL: func.func @dead_affine_region_op func.func @dead_affine_region_op() { %c1 = arith.constant 1 : index diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index dc556761b09e5..9459309eb4c0d 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -497,6 +497,48 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>, // ----- +func.func 
@reduce_asymmetric(%input: tensor<16x32x64xi32>, %input2: tensor<16x32x64xi32>, + %init: tensor<16x64xi32>) -> tensor<16x64xi32> { + %reduce = linalg.reduce + ins(%input, %input2:tensor<16x32x64xi32>, tensor<16x32x64xi32>) + outs(%init:tensor<16x64xi32>) + dimensions = [1] + (%in: i32, %in2: i32, %out: i32) { + %0 = arith.muli %in, %in2: i32 + %1 = arith.addi %out, %0: i32 + linalg.yield %1: i32 + } + func.return %reduce : tensor<16x64xi32> +} +// CHECK-LABEL: func @reduce_asymmetric +// CHECK: linalg.reduce ins(%{{.*}}, %{{.*}}: tensor<16x32x64xi32>, tensor<16x32x64xi32>) +// CHECK-NOT: operandSegmentSize +// CHECK-SAME: outs(%{{.*}}: tensor<16x64xi32>) +// CHECK-SAME: dimensions = [1] + +// ----- + +func.func @reduce_asymmetric_memref(%input: memref<16x32x64xi32>, %input2: memref<16x32x64xi32>, + %init: memref<16x64xi32>) { + linalg.reduce + ins(%input, %input2:memref<16x32x64xi32>, memref<16x32x64xi32>) + outs(%init:memref<16x64xi32>) + dimensions = [1] + (%in: i32, %in2: i32, %out: i32) { + %0 = arith.muli %in, %in2: i32 + %1 = arith.addi %out, %0: i32 + linalg.yield %1: i32 + } + func.return +} +// CHECK-LABEL: func @reduce_asymmetric_memref +// CHECK: linalg.reduce ins(%{{.*}}, %{{.*}}: memref<16x32x64xi32>, memref<16x32x64xi32>) +// CHECK-NOT: operandSegmentSize +// CHECK-SAME: outs(%{{.*}}: memref<16x64xi32>) +// CHECK-SAME: dimensions = [1] + +// ----- + func.func @transpose(%input: tensor<16x32x64xf32>, %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { %transpose = linalg.transpose diff --git a/mlir/test/Pass/recursive-pipeline-anchor.mlir b/mlir/test/Pass/recursive-pipeline-anchor.mlir new file mode 100644 index 0000000000000..f8464fe77ae22 --- /dev/null +++ b/mlir/test/Pass/recursive-pipeline-anchor.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='builtin.module(**func.func(test-function-pass))' -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s +// RUN: mlir-opt %s 
-mlir-disable-threading -test-function-pass --pass-pipeline-anchor=func.func -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s + +// some with threading enabled + +// RUN: mlir-opt %s -pass-pipeline='builtin.module(**func.func(test-function-pass))' -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s +// RUN: mlir-opt %s -test-function-pass --pass-pipeline-anchor=func.func -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s + +// some without recursion + +// RUN: mlir-opt %s -mlir-disable-threading -test-function-pass -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s --check-prefix=NON_REC_CHECK +// RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='builtin.module(func.func(test-function-pass))' -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s --check-prefix=NON_REC_CHECK + +func.func @foo() { + return +} + +module { + func.func @bar() { + return + } +} + +// with recursive anchor the pass is executed on @foo and @bar + +// CHECK: TestFunctionPass +// CHECK-NEXT: (S) 2 counter - Number of invocations + +// in non-recursive mode the pass is only executed on @foo + +// NON_REC_CHECK: TestFunctionPass +// NON_REC_CHECK-NEXT: (S) 1 counter - Number of invocations \ No newline at end of file diff --git a/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp b/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp index 751302550092d..f48091c9c0893 100644 --- a/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp @@ -12,7 +12,7 @@ #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" #include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/LoopFusionUtils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Pass/Pass.h" 
#define PASS_NAME "test-affine-access-analysis" @@ -23,7 +23,7 @@ using namespace mlir::affine; namespace { struct TestAccessAnalysis - : public PassWrapper> { + : public PassWrapper { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAccessAnalysis) StringRef getArgument() const final { return PASS_NAME; } @@ -52,7 +52,7 @@ void TestAccessAnalysis::runOnOperation() { SmallVector enclosingOps; // Go over all top-level affine.for ops and test each contained affine // access's contiguity along every surrounding loop IV. - for (auto forOp : getOperation().getOps()) { + for (auto forOp : getOperation()->getRegion(0).getOps()) { loadStores.clear(); gatherLoadsAndStores(forOp, loadStores); for (Operation *memOp : loadStores) { diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp index 404f34ebee17a..742f9dae9e619 100644 --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -14,7 +14,7 @@ #include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -28,7 +28,7 @@ using namespace mlir::affine; namespace { struct TestAffineDataCopy - : public PassWrapper> { + : public PassWrapper { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineDataCopy) StringRef getArgument() const final { return PASS_NAME; } diff --git a/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp b/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp index f8e76356c4321..8e47e7c25da2e 100644 --- a/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp @@ -23,8 +23,7 @@ using namespace 
mlir::affine; namespace { struct TestAffineLoopParametricTiling - : public PassWrapper> { + : public PassWrapper { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineLoopParametricTiling) StringRef getArgument() const final { return "test-affine-parametric-tile"; } diff --git a/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp b/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp index 7e4a3ca7b7c72..87c10b74b0cb7 100644 --- a/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" @@ -25,7 +26,7 @@ namespace { /// This pass applies the permutation on the first maximal perfect nest. struct TestAffineLoopUnswitching - : public PassWrapper> { + : public PassWrapper { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineLoopUnswitching) StringRef getArgument() const final { return PASS_NAME; } diff --git a/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp b/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp index 429784f26e038..1faf01f51ec25 100644 --- a/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp +++ b/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp @@ -26,7 +26,7 @@ using namespace mlir::affine; namespace { struct TestDecomposeAffineOps - : public PassWrapper> { + : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeAffineOps) StringRef getArgument() const final { return PASS_NAME; } @@ -43,7 +43,7 @@ struct TestDecomposeAffineOps void TestDecomposeAffineOps::runOnOperation() { IRRewriter rewriter(&getContext()); - this->getOperation().walk([&](AffineApplyOp op) { + this->getOperation()->walk([&](AffineApplyOp op) { rewriter.setInsertionPoint(op); 
reorderOperandsByHoistability(rewriter, op); (void)decompose(rewriter, op); diff --git a/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp b/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp index 19011803a793a..0352c84391d10 100644 --- a/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp +++ b/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopFusionUtils.h" #include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" @@ -25,7 +26,7 @@ using namespace mlir::affine; namespace { struct TestLoopFusion - : public PassWrapper> { + : public PassWrapper { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion) StringRef getArgument() const final { return "test-loop-fusion"; } diff --git a/mlir/test/lib/Dialect/Affine/TestLoopMapping.cpp b/mlir/test/lib/Dialect/Affine/TestLoopMapping.cpp index 3dc7abb15af17..429f7ea42cbe8 100644 --- a/mlir/test/lib/Dialect/Affine/TestLoopMapping.cpp +++ b/mlir/test/lib/Dialect/Affine/TestLoopMapping.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" @@ -22,7 +23,7 @@ using namespace mlir::affine; namespace { struct TestLoopMappingPass - : public PassWrapper> { + : public PassWrapper { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopMappingPass) StringRef getArgument() const final { diff --git a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp index e708b7de690ec..7e8a6779ea7d7 100644 --- a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp +++ b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp @@ -13,7 +13,7 @@ #include "mlir/Dialect/Affine/Analysis/Utils.h" #include 
"mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Pass/Pass.h" +#include "mlir/Dialect/Affine/Passes.h" #define PASS_NAME "test-loop-permutation" @@ -24,7 +24,7 @@ namespace { /// This pass applies the permutation on the first maximal perfect nest. struct TestLoopPermutation - : public PassWrapper> { + : public PassWrapper { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopPermutation) StringRef getArgument() const final { return PASS_NAME; } diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index 891b3bab8629d..ed05c92d48491 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" @@ -31,8 +32,7 @@ namespace { /// This pass applies the permutation on the first maximal perfect nest. struct TestReifyValueBounds - : public PassWrapper> { + : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds) StringRef getArgument() const final { return PASS_NAME; } @@ -76,11 +76,11 @@ invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) { /// Look for "test.reify_bound" ops in the input and replace their results with /// the reified values. 
-static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp, +static LogicalResult testReifyValueBounds(Operation* funcOp, bool reifyToFuncArgs, bool useArithOps) { - IRRewriter rewriter(funcOp.getContext()); - WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) { + IRRewriter rewriter(funcOp->getContext()); + WalkResult result = funcOp->walk([&](test::ReifyBoundOp op) { auto boundType = op.getBoundType(); Value value = op.getVar(); std::optional dim = op.getDim(); @@ -158,9 +158,9 @@ static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp, } /// Look for "test.compare" ops and emit errors/remarks. -static LogicalResult testEquality(FunctionOpInterface funcOp) { - IRRewriter rewriter(funcOp.getContext()); - WalkResult result = funcOp.walk([&](test::CompareOp op) { +static LogicalResult testEquality(Operation* funcOp) { + IRRewriter rewriter(funcOp->getContext()); + WalkResult result = funcOp->walk([&](test::CompareOp op) { auto cmpType = op.getComparisonOperator(); if (op.getCompose()) { if (cmpType != ValueBoundsConstraintSet::EQ) { diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp index c32bd24014215..d5827f8989e04 100644 --- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" @@ -24,7 +25,6 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" @@ -39,7 +39,7 @@ static llvm::cl::OptionCategory 
clOptionsCategory(DEBUG_TYPE " options"); namespace { struct VectorizerTestPass - : public PassWrapper> { + : public PassWrapper { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorizerTestPass) static constexpr auto kTestAffineMapOpName = "test_affine_map"; @@ -53,7 +53,7 @@ struct VectorizerTestPass } VectorizerTestPass() = default; - VectorizerTestPass(const VectorizerTestPass &pass) : PassWrapper(pass){}; + VectorizerTestPass(const VectorizerTestPass &pass) : PassWrapper(pass) {}; ListOption clTestVectorShapeRatio{ *this, "vector-shape-ratio", @@ -97,11 +97,12 @@ struct VectorizerTestPass } // namespace void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) { - auto f = getOperation(); + auto *f = getOperation(); using affine::matcher::Op; SmallVector shape(clTestVectorShapeRatio.begin(), clTestVectorShapeRatio.end()); - auto subVectorType = VectorType::get(shape, Float32Type::get(f.getContext())); + auto subVectorType = + VectorType::get(shape, FloatType::getF32(f->getContext())); // Only filter operations that operate on a strict super-vector and have one // return. This makes testing easier. 
auto filter = [&](Operation &op) { @@ -147,7 +148,7 @@ static NestedPattern patternTestSlicingOps() { } void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) { - auto f = getOperation(); + auto f = cast(getOperation()); outs << "\n" << f.getName(); SmallVector matches; @@ -163,7 +164,7 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) { } void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) { - auto f = getOperation(); + auto f = cast(getOperation()); outs << "\n" << f.getName(); SmallVector matches; @@ -179,7 +180,7 @@ void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) { } void VectorizerTestPass::testSlicing(llvm::raw_ostream &outs) { - auto f = getOperation(); + auto f = cast(getOperation()); outs << "\n" << f.getName(); SmallVector matches; @@ -198,7 +199,7 @@ static bool customOpWithAffineMapAttribute(Operation &op) { } void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) { - auto f = getOperation(); + auto f = cast(getOperation()); using affine::matcher::Op; auto pattern = Op(customOpWithAffineMapAttribute); @@ -252,7 +253,7 @@ void VectorizerTestPass::testVecAffineLoopNest(llvm::raw_ostream &outs) { void VectorizerTestPass::runOnOperation() { // Only support single block functions at this point. 
- func::FuncOp f = getOperation(); + func::FuncOp f = cast(getOperation()); if (!llvm::hasSingleElement(f)) return; diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp index 7afe2109f04db..9289300f5ba4c 100644 --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -29,8 +29,14 @@ struct TestModulePass struct TestFunctionPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFunctionPass) + TestFunctionPass() = default; + TestFunctionPass(const TestFunctionPass& pass) {} - void runOnOperation() final {} + Statistic callCount{this, "counter", "Number of invocations"}; + + void runOnOperation() final { + callCount++; + } StringRef getArgument() const final { return "test-function-pass"; } StringRef getDescription() const final { return "Test a function pass in the pass manager";