Skip to content

Commit 0a8e3cc

Browse files
authored
[AMD][NFC] Extract range analysis into its own class (triton-lang#5977)
This PR factors `TritonIntegerRangeAnalysis` out of `ConvertBufferOps` into a standalone analysis that can be reused in other passes.
1 parent edd346d commit 0a8e3cc

File tree

15 files changed

+1643
-534
lines changed

15 files changed

+1643
-534
lines changed

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ target_link_libraries(triton-opt PRIVATE
1313
# tests
1414
TritonTestAnalysis
1515
TritonTestDialectTritonGPU
16+
TritonAMDGPUTestAnalysis
1617
# MLIR core
1718
MLIROptLib
1819
MLIRPass
@@ -32,6 +33,7 @@ target_link_libraries(triton-reduce PRIVATE
3233
# tests
3334
TritonTestAnalysis
3435
TritonTestDialectTritonGPU
36+
TritonAMDGPUTestAnalysis
3537
# MLIR core
3638
MLIRReduceLib
3739
MLIRPass
@@ -50,6 +52,7 @@ target_link_libraries(triton-lsp PRIVATE
5052
# tests
5153
TritonTestAnalysis
5254
TritonTestDialectTritonGPU
55+
TritonAMDGPUTestAnalysis
5356
# MLIR core
5457
MLIRLspServerLib
5558
MLIRPass
@@ -86,4 +89,5 @@ target_link_libraries(triton-tensor-layout PRIVATE
8689
${dialect_libs}
8790
TritonTestAnalysis
8891
TritonTestDialectTritonGPU
92+
TritonAMDGPUTestAnalysis
8993
)

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ void registerTestAliasPass();
3232
void registerTestAlignmentPass();
3333
void registerTestAllocationPass();
3434
void registerTestMembarPass();
35+
void registerTestTritonAMDGPURangeAnalysis();
3536
} // namespace test
3637
} // namespace mlir
3738

@@ -44,6 +45,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
4445
mlir::test::registerTestAlignmentPass();
4546
mlir::test::registerTestAllocationPass();
4647
mlir::test::registerTestMembarPass();
48+
mlir::test::registerTestTritonAMDGPURangeAnalysis();
4749
mlir::triton::registerConvertTritonToTritonGPUPass();
4850
mlir::triton::registerAllocateSharedMemoryPass();
4951
mlir::triton::registerTritonGPUGlobalScratchAllocationPass();

test/TritonGPU/amd/amd-convert-buffer-ops-range-analysis.mlir

Lines changed: 185 additions & 185 deletions
Large diffs are not rendered by default.

test/TritonGPU/amd/amd-range-analysis.mlir

Lines changed: 931 additions & 0 deletions
Large diffs are not rendered by default.

third_party/amd/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ endif()
99
if(TRITON_BUILD_UT)
1010
add_subdirectory(unittest)
1111
endif()
12+
add_subdirectory(test)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#ifndef TRITONAMD_ANALYSIS_RANGE_ANALYSIS_H
2+
#define TRITONAMD_ANALYSIS_RANGE_ANALYSIS_H
3+
4+
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
5+
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
6+
#include "mlir/Interfaces/LoopLikeInterface.h"
7+
8+
namespace mlir::triton::AMD {
9+
/// This struct (analysis) adapt's upstream's IntegerRangeAnalysis (inferring
10+
/// lower/upperbounds on integer constants) to our needs.
11+
/// Specifically there are 2 points of extension:
12+
///
13+
/// 1. Support for GetProgramIdOp, MakeRangeOp, SplatOp, ExpandDimsOp. *Note*,
14+
/// upstream already supports range inference for shaped types such as tensors
15+
/// (here we just implement effectively implement the interfaces for our ops).
16+
/// * Upstream's semantics for "range of shape type" is union over ranges of
17+
/// elements.
18+
/// * We do not use tablegen to implement
19+
/// DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
20+
/// in order to keep the entire implementation contained/encapsulated.
21+
///
22+
/// 2. Support for inference "through loops". Upstream's analysis conservatively
23+
/// inferences [min_int, max_int] for loop carried values (and therefore loop
24+
/// body values). Here we attempt to do better by analysis the loop bounds and
25+
/// "abstractly interpreting" the loop when loop bounds are statically known.
26+
/// See visitRegionSuccessors.
27+
struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
28+
using dataflow::IntegerRangeAnalysis::IntegerRangeAnalysis;
29+
30+
llvm::SmallDenseMap<LoopLikeOpInterface, int64_t> loopTripCounts;
31+
llvm::SmallDenseMap<
32+
std::pair<LoopLikeOpInterface, dataflow::IntegerValueRangeLattice *>,
33+
int64_t>
34+
loopVisits;
35+
36+
void setToEntryState(dataflow::IntegerValueRangeLattice *lattice) override;
37+
38+
LogicalResult visitOperation(
39+
Operation *op,
40+
ArrayRef<const dataflow::IntegerValueRangeLattice *> operands,
41+
ArrayRef<dataflow::IntegerValueRangeLattice *> results) override;
42+
43+
/// This method (which overloads
44+
/// AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors) implements
45+
/// "abstract interpretation" of loops with statically known bounds in order
46+
/// to infer tight ranges for loop carried values (and therefore loop body
47+
/// values). By "abstract interpretation" we mean lattice states are
48+
/// propagated to all region successors N times, where N is the total trip
49+
/// count of the loop. Recall for scf.for, both the loop itself and the users
50+
/// of the loop successors. Thus, after N propagations both loop body values
51+
/// and users of loop results will have accurate ranges (assuming we have
52+
/// implemented support for range analysis on the ops).
53+
/// *Note*, this implementation is majority similar to
54+
/// AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors (so check
55+
/// there for more explanation/insight) and basically only does two things
56+
/// differently:
57+
///
58+
/// 1. If the branch op is a loop (LoopLikeOpInterface) then we attempt to
59+
/// compute its total trip count (nested loop trip counts multiply) and
60+
/// initialize a visit count to 0. Note, due to how Dataflow analysis works we
61+
/// have to actually visit the loop N times for each iter_arg (each argument
62+
/// lattice) so we actually track visit count for (loop, arg) not just (loop).
63+
///
64+
/// 2. Before propagating, we check if we have propagated for (loop, arg) >= N
65+
/// times. If so, we do not propagate (and thus the traversal converges/ends).
66+
///
67+
/// Note, for loops where the trip count cannot be inferred *and* loops with a
68+
/// total trip count larger than `kDefaultMaxTripCount`, fallback to
69+
/// upstream's conservative inference (i.e., we infer [min_int, max_int]) for
70+
/// the loop operands and all users and all users of the results of the loop.
71+
void visitRegionSuccessors(
72+
ProgramPoint *point, RegionBranchOpInterface branch,
73+
RegionBranchPoint successor,
74+
ArrayRef<dataflow::AbstractSparseLattice *> abstractLattices) override;
75+
};
76+
77+
// TODO(max): remove after we catch up to
78+
// https://github.com/llvm/llvm-project/pull/127888
79+
LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v);
80+
LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op);
81+
82+
} // namespace mlir::triton::AMD
83+
84+
#endif
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
add_triton_library(TritonAMDAnalysis
2+
RangeAnalysis.cpp
3+
4+
DEPENDS
5+
TritonTableGen
6+
TritonGPUTableGen
7+
TritonGPUAttrDefsIncGen
8+
TritonGPUTypeInterfacesIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRAnalysis
12+
MLIRLLVMDialect
13+
TritonIR
14+
TritonGPUIR
15+
)

0 commit comments

Comments
 (0)