|
| 1 | +#ifndef TRITONAMD_ANALYSIS_RANGE_ANALYSIS_H |
| 2 | +#define TRITONAMD_ANALYSIS_RANGE_ANALYSIS_H |
| 3 | + |
| 4 | +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" |
| 5 | +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
| 6 | +#include "mlir/Interfaces/LoopLikeInterface.h" |
| 7 | + |
| 8 | +namespace mlir::triton::AMD { |
| 9 | +/// This struct (analysis) adapts upstream's IntegerRangeAnalysis (inferring |
| 10 | +/// lower/upper bounds on integer constants) to our needs. |
| 11 | +/// Specifically there are 2 points of extension: |
| 12 | +/// |
| 13 | +/// 1. Support for GetProgramIdOp, MakeRangeOp, SplatOp, ExpandDimsOp. *Note*, |
| 14 | +/// upstream already supports range inference for shaped types such as tensors |
| 15 | +/// (here we just effectively implement the interfaces for our ops). |
| 16 | +/// * Upstream's semantics for "range of shaped type" is union over ranges of |
| 17 | +/// elements. |
| 18 | +/// * We do not use tablegen to implement |
| 19 | +/// DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]> |
| 20 | +/// in order to keep the entire implementation contained/encapsulated. |
| 21 | +/// |
| 22 | +/// 2. Support for inference "through loops". Upstream's analysis conservatively |
| 23 | +/// infers [min_int, max_int] for loop carried values (and therefore loop |
| 24 | +/// body values). Here we attempt to do better by analyzing the loop bounds and |
| 25 | +/// "abstractly interpreting" the loop when loop bounds are statically known. |
| 26 | +/// See visitRegionSuccessors. |
| 27 | +struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { |
| 28 | + using dataflow::IntegerRangeAnalysis::IntegerRangeAnalysis; |
| 29 | + |
| 30 | + llvm::SmallDenseMap<LoopLikeOpInterface, int64_t> loopTripCounts; |
| 31 | + llvm::SmallDenseMap< |
| 32 | + std::pair<LoopLikeOpInterface, dataflow::IntegerValueRangeLattice *>, |
| 33 | + int64_t> |
| 34 | + loopVisits; |
| 35 | + |
| 36 | + void setToEntryState(dataflow::IntegerValueRangeLattice *lattice) override; |
| 37 | + |
| 38 | + LogicalResult visitOperation( |
| 39 | + Operation *op, |
| 40 | + ArrayRef<const dataflow::IntegerValueRangeLattice *> operands, |
| 41 | + ArrayRef<dataflow::IntegerValueRangeLattice *> results) override; |
| 42 | + |
| 43 | + /// This method (which overrides |
| 44 | + /// AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors) implements |
| 45 | + /// "abstract interpretation" of loops with statically known bounds in order |
| 46 | + /// to infer tight ranges for loop carried values (and therefore loop body |
| 47 | + /// values). By "abstract interpretation" we mean lattice states are |
| 48 | + /// propagated to all region successors N times, where N is the total trip |
| 49 | + /// count of the loop. Recall for scf.for, both the loop itself and the users |
| 50 | + /// of the loop are successors. Thus, after N propagations both loop body values |
| 51 | + /// and users of loop results will have accurate ranges (assuming we have |
| 52 | + /// implemented support for range analysis on the ops). |
| 53 | + /// *Note*, this implementation is largely similar to |
| 54 | + /// AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors (so check |
| 55 | + /// there for more explanation/insight) and basically only does two things |
| 56 | + /// differently: |
| 57 | + /// |
| 58 | + /// 1. If the branch op is a loop (LoopLikeOpInterface) then we attempt to |
| 59 | + /// compute its total trip count (nested loop trip counts multiply) and |
| 60 | + /// initialize a visit count to 0. Note, due to how Dataflow analysis works we |
| 61 | + /// have to actually visit the loop N times for each iter_arg (each argument |
| 62 | + /// lattice) so we actually track visit count for (loop, arg) not just (loop). |
| 63 | + /// |
| 64 | + /// 2. Before propagating, we check if we have propagated for (loop, arg) >= N |
| 65 | + /// times. If so, we do not propagate (and thus the traversal converges/ends). |
| 66 | + /// |
| 67 | + /// Note, for loops where the trip count cannot be inferred *and* loops with a |
| 68 | + /// total trip count larger than `kDefaultMaxTripCount`, we fall back to |
| 69 | + /// upstream's conservative inference (i.e., we infer [min_int, max_int]) for |
| 70 | + /// the loop operands and all their users, as well as users of the loop results. |
| 71 | + void visitRegionSuccessors( |
| 72 | + ProgramPoint *point, RegionBranchOpInterface branch, |
| 73 | + RegionBranchPoint successor, |
| 74 | + ArrayRef<dataflow::AbstractSparseLattice *> abstractLattices) override; |
| 75 | +}; |
| 76 | + |
| 77 | +// TODO(max): remove (presumably both overloads below) after we catch up to |
| 78 | +// https://github.com/llvm/llvm-project/pull/127888 |
| 79 | +LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v); |
| 80 | +LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op); |
| 81 | + |
| 82 | +} // namespace mlir::triton::AMD |
| 83 | + |
| 84 | +#endif |
0 commit comments