Skip to content

Commit 40392dc

Browse files
Merge commit 'e00903a0ab6843b56962d1b7899a3a3569e6d801'
2 parents 4bbf937 + e00903a commit 40392dc

36 files changed

+1456
-665
lines changed

.github/workflows/integration-tests.yml

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -279,8 +279,9 @@ jobs:
279279
ctest -j32
280280
- name: Run Proton tests
281281
run: |
282-
cd third_party/proton
283-
python3 -m pytest -s test
282+
cd third_party/proton/test
283+
python3 -m pytest -s .
284+
cd ..
284285
- # If we're on branch `main`, save the ccache Triton compilation artifacts
285286
# to the cache so they can be used by other (non-main) CI runs.
286287
#
@@ -425,8 +426,9 @@ jobs:
425426
python3 -m pytest -s -n 8 ./test_cast_matmul.py
426427
- name: Run Proton tests
427428
run: |
428-
cd third_party/proton
429-
python3 -m pytest -s test
429+
cd third_party/proton/test
430+
python3 -m pytest -s .
431+
cd ..
430432
- name: Run C++ unittests
431433
run: |
432434
cd python

.github/workflows/integration-tests.yml.in

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -319,8 +319,9 @@ jobs:
319319
- &run-proton-tests-step
320320
name: Run Proton tests
321321
run: |
322-
cd third_party/proton
323-
python3 -m pytest -s test
322+
cd third_party/proton/test
323+
python3 -m pytest -s .
324+
cd ..
324325

325326
# If we're on branch `main`, save the ccache Triton compilation artifacts
326327
# to the cache so they can be used by other (non-main) CI runs.

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -280,7 +280,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<In
280280
def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[MemAlloc<GlobalMemory>]>]> {
281281
let summary = "allocate a global memory buffer";
282282
let description = [{
283-
This operation allocates a buffer in global memory.
283+
This operation allocates a buffer in global memory that is private to the current program.
284284
}];
285285
let arguments = (
286286
ins

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -179,4 +179,17 @@ def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init"
179179
"mlir::triton::TritonDialect"];
180180
}
181181

182+
def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp"> {
183+
let summary = "Generate loop scheduling for SWP";
184+
185+
let description = "This pass sets up stages and clustering for software pipelining.";
186+
187+
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
188+
"mlir::triton::TritonDialect"];
189+
let options = [
190+
Option<"numStages", "num-stages",
191+
"int32_t", /*default*/"3",
192+
"number of pipeline stages">
193+
];
194+
}
182195
#endif

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@ namespace mlir {
88
namespace triton {
99

1010
static const char *kNumStagesAttrName = "tt.num_stages";
11+
static const char *kLoopStageAttrName = "loop.stage";
12+
static const char *kLoopClusterAttrName = "loop.cluster";
1113

1214
/// Function to mask operations during scheduling.
1315
Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred);
@@ -29,6 +31,11 @@ void addOps(scf::ForOp forOp, int stage,
2931
/// mutable.
3032
void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
3133
Value val);
34+
35+
// Return the minClusterId and maxClusterId for the given ForOp.
36+
std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
37+
std::pair<int, int> getStageCluster(Operation *op);
38+
void setStageCluster(scf::ForOp &forOp, Operation *op, int stage, int cluster);
3239
} // namespace triton
3340
} // namespace mlir
3441

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -101,8 +101,16 @@ class CoarseSchedule {
101101
createFinalSchedule(scf::ForOp forOp);
102102
void dump();
103103
bool empty() { return opToStageAndCluster.size() == 0; }
104+
void serialize(scf::ForOp &forOp);
105+
// Create a CoarseSchedule based on forOp's <stage, cluster>.
106+
void deSerialize(scf::ForOp &forOp);
104107
};
105108

109+
// Add dependencies of anchor ops to the coarse schedule. Schedule them to
110+
// the same stage and ordering cluster as the anchor op.
111+
void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule,
112+
int numStages);
113+
106114
} // namespace triton
107115
} // namespace mlir
108116
#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -192,6 +192,10 @@ bool isPureUnaryInlineAsm(Operation *op);
192192
// read the compute capability from the module attributes
193193
int getNVIDIAComputeCapability(Operation *module);
194194

195+
std::optional<mlir::triton::gpu::SharedEncodingAttr>
196+
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);
197+
198+
bool loadIsMMAv3(Operation *loadOp);
195199
} // namespace mlir
196200

197201
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,3 +3,4 @@ add_subdirectory(Conversion)
33
add_subdirectory(Dialect)
44
add_subdirectory(Target)
55
add_subdirectory(Tools)
6+
add_subdirectory(Instrumentation)

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ add_triton_library(TritonGPUTransforms
33
Coalesce.cpp
44
F32DotTC.cpp
55
CombineTensorSelectAndIf.cpp
6+
LoopScheduling.cpp
67
ReduceDataDuplication.cpp
78
OptimizeAccumulatorInit.cpp
89
OptimizeDotOperands.cpp

0 commit comments

Comments
 (0)