
Commit 4a0c8fc

Merge commit '557a7094e6fc742360297905297b0b50e0bc8088'
2 parents: b665d5a + 557a709


50 files changed: +1217 −623 lines

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
@@ -40,6 +40,12 @@ repos:
   hooks:
   - id: clang-format

+- repo: https://github.com/pre-commit/mirrors-mypy
+  rev: "v1.15.0"
+  hooks:
+  - id: mypy
+    pass_filenames: false
+
 # Expand YAML anchors in files used by github workflows, because github can't
 # do this itself. This lets us use anchors, which avoids code duplication.
 - repo: local

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 43 additions & 0 deletions
@@ -729,6 +729,49 @@ SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,

 Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter);

+inline std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
+  switch (atomicOp) {
+  case RMWOp::AND:
+    return LLVM::AtomicBinOp::_and;
+  case RMWOp::OR:
+    return LLVM::AtomicBinOp::_or;
+  case RMWOp::XOR:
+    return LLVM::AtomicBinOp::_xor;
+  case RMWOp::ADD:
+    return LLVM::AtomicBinOp::add;
+  case RMWOp::FADD:
+    return LLVM::AtomicBinOp::fadd;
+  case RMWOp::MAX:
+    return LLVM::AtomicBinOp::max;
+  case RMWOp::MIN:
+    return LLVM::AtomicBinOp::min;
+  case RMWOp::UMAX:
+    return LLVM::AtomicBinOp::umax;
+  case RMWOp::UMIN:
+    return LLVM::AtomicBinOp::umin;
+  case RMWOp::XCHG:
+    return LLVM::AtomicBinOp::xchg;
+  default:
+    return {};
+  }
+}
+
+inline std::optional<LLVM::AtomicOrdering>
+getMemoryOrdering(MemSemantic memOrdering) {
+  switch (memOrdering) {
+  case MemSemantic::RELAXED:
+    return LLVM::AtomicOrdering::monotonic;
+  case MemSemantic::ACQUIRE:
+    return LLVM::AtomicOrdering::acquire;
+  case MemSemantic::RELEASE:
+    return LLVM::AtomicOrdering::release;
+  case MemSemantic::ACQUIRE_RELEASE:
+    return LLVM::AtomicOrdering::acq_rel;
+  default:
+    return {};
+  }
+}
+
 inline bool
 isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
                            ArrayRef<int64_t> allocShape,
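These helpers centralize the mapping from Triton's RMWOp and MemSemantic enums to their LLVM-dialect counterparts, returning an empty optional for unsupported values. Below is a minimal sketch of how a backend lowering might consume them; the op accessors (getAtomicRmwOp, getSem) and the rewriter plumbing are illustrative assumptions, not code from this commit.

// Sketch only: assumes `op` exposes its RMW kind and memory semantic, and that
// the pointer/value operands have already been converted to LLVM types.
LogicalResult lowerAtomicRMW(triton::AtomicRMWOp op, Value llPtr, Value llVal,
                             RewriterBase &rewriter) {
  std::optional<LLVM::AtomicBinOp> binOp = matchAtomicOp(op.getAtomicRmwOp());
  std::optional<LLVM::AtomicOrdering> ordering = getMemoryOrdering(op.getSem());
  if (!binOp || !ordering)
    return failure(); // combination not representable in the LLVM dialect
  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(op, *binOp, llPtr, llVal,
                                                 *ordering);
  return success();
}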

include/triton/Dialect/Triton/IR/TritonDialect.td

Lines changed: 2 additions & 1 deletion
@@ -45,7 +45,8 @@ def Triton_Dialect : Dialect {

  let discardableAttrs = (ins
    "::mlir::IntegerAttr":$num_stages,
-   "::mlir::IntegerAttr":$latency
+   "::mlir::IntegerAttr":$latency,
+   "::mlir::IntegerAttr":$self_latency
  );

  let hasConstantMaterializer = 1;
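self_latency joins the existing num_stages and latency discardable attributes on the Triton dialect. A hedged sketch of reading it back off an operation follows; the "tt.self_latency" spelling is an assumption for illustration and is not shown anywhere in this diff.

// Assumption: the serialized attribute is named "tt.self_latency"; the exact
// spelling is an implementation detail not visible in this commit. Returns 0
// when the attribute is absent.
static int getSelfLatency(mlir::Operation *op) {
  if (auto attr = op->getAttrOfType<mlir::IntegerAttr>("tt.self_latency"))
    return attr.getInt();
  return 0;
}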

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 31 additions & 11 deletions
@@ -20,11 +20,37 @@ namespace triton::nvidia_gpu {
 // Given an MMAv5 operation in a loop, determine if its accumulator can be
 // multibuffered.
 bool isAccMultibufferingPossible(MMAv5OpInterface mma, scf::ForOp forOp);
-// Only pipeline the loops where the MMA happens before the tmem_load, or is in
-// the same stage as the tmem_load. Lowering does not support the case where the
-// MMA is in a different stage as the tmem_load and happens after it.
-bool mmav5DominatesTmemLoads(
-    scf::ForOp forOp, function_ref<bool(MMAv5OpInterface)> isMmaPipelineable);
+
+// Returns true if the MMA operation requires acc multi-buffering when
+// pipelined.
+bool requiresAccMultiBuffering(MMAv5OpInterface mma, scf::ForOp forOp);
+
+// Returns true if there are loads from tmem after the MMA operation.
+bool hasLoadsAfterMMA(MMAv5OpInterface mma, scf::ForOp forOp);
+
+// Helper class to determine if the operands of an MMA operation are
+// pipelineable.
+class MMAv5PipelineableOperandsHelper {
+public:
+  MMAv5PipelineableOperandsHelper(
+      MMAv5OpInterface mmaOp, scf::ForOp forOp,
+      std::function<bool(Operation *)> isLoadToBePipelined)
+      : mmaOp(mmaOp), forOp(forOp), isLoadToBePipelined(isLoadToBePipelined) {
+    run();
+  }
+  bool isPipelineable = false;
+  // If true, all existing operand loads have been found and their
+  // pipelineability has been determined.
+  bool isOperandsStateDetermined = false;
+  SmallVector<Operation *> unpipelineableOperandLoads;
+
+private:
+  MMAv5OpInterface mmaOp;
+  scf::ForOp forOp;
+  std::function<bool(Operation *)> isLoadToBePipelined;
+  bool comesFromLoadOrOutsideLoop(Value v, Operation *&foundLoad);
+  void run();
+};

 //===----------------------------------------------------------------------===//
 // MMA Pipeline Rewriters
@@ -35,12 +61,6 @@ bool mmav5DominatesTmemLoads(
 TMEMAllocOp createTMemAlloc(OpBuilder &builder, TMEMAllocOp oldTMemAllocOp,
                             bool multiBufferred, int numStages);

-// Return true if operands of the MMA operation are/are going to be pipelined
-// and multibuffered, enabling the MMA operation to be pipelined.
-bool mmaHasPipelineableOperands(
-    MMAv5OpInterface mma, scf::ForOp forOp,
-    std::function<bool(Operation *)> isLoadPipelineable);
-
 // Return true if the accumulator of an mma in subsequent iterations is either
 // independent from the previous iteration (overwritten) or completely reused,
 // without read-modify-write.
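MMAv5PipelineableOperandsHelper replaces the deleted mmaHasPipelineableOperands free function and reports more than a single bool: whether the operands are pipelineable, whether every operand load could even be identified, and which loads block pipelining. A usage sketch under the same assumptions as AssignLatencies.cpp below (ttng abbreviates triton::nvidia_gpu; `opLatency` stands in for the caller's latency map):

// Sketch: classify the operands of one MMA inside a loop.
static void classifyMmaOperands(ttng::MMAv5OpInterface mma, scf::ForOp forOp,
                                DenseMap<Operation *, int> &opLatency) {
  auto isLoadToBePipelined = [&](Operation *op) {
    return opLatency.count(op) && opLatency[op] > 0;
  };
  ttng::MMAv5PipelineableOperandsHelper helper(mma, forOp, isLoadToBePipelined);
  if (helper.isPipelineable) {
    // All operand loads are (or will be) pipelined; the MMA can be pipelined.
  } else if (helper.isOperandsStateDetermined) {
    // Every operand load was found, but some cannot be pipelined; they are
    // listed in helper.unpipelineableOperandLoads.
  } else {
    // Some operand could not be traced to a load or a loop-invariant value,
    // so nothing can be concluded about it.
  }
}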

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 8 additions & 0 deletions
@@ -91,6 +91,11 @@ int getCopyVecBytes(RankedTensorType registerTy,
 // attribute.
 void serializeLatencies(ModuleOp module, DenseMap<Operation *, int> &opLatency);

+// Serialize the self latencies of the operations in the loops into the
+// self_latency attribute.
+void serializeSelfLatencies(ModuleOp module,
+                            DenseMap<Operation *, int> &opSelfLatency);
+
 // Deserialize the latencies of the operations in the loops from the attribute.
 DenseMap<Operation *, int> deserializeLatencies(Operation *op);

@@ -107,6 +112,9 @@ Value createAlloc(scf::ForOp forOp, RankedTensorType ty, Location loc,
 // Determine if the operation is a TMA load.
 bool isTMALoad(Operation *op);

+// Determine if the operation can be lowered to an async load.
+bool canBeAsyncLoad(Operation *op);
+
 // Look for consecutive wait ops and combine them into a single wait op.
 void combineRedundantWaitOps(
     llvm::SmallSetVector<gpu::AsyncWaitOp, 8> &waitOps);
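serializeSelfLatencies mirrors serializeLatencies but writes the new self_latency attribute, and canBeAsyncLoad generalizes the existing isTMALoad query. A hedged sketch of the intended call pattern, matching how AssignLatencies.cpp uses it further down; names and namespace qualification are abbreviated:

// Sketch: record that an MMA may overlap with itself (self latency) and that
// its users may move one stage later (regular latency), then serialize both
// maps onto the module.
DenseMap<Operation *, int> opLatency, opSelfLatency;
opSelfLatency[mma] = 1; // the MMA can be overlapped with itself
opLatency[mma] = 1;     // the MMA's users can be pushed to the next stage
ModuleOp module = forOp->getParentOfType<ModuleOp>();
serializeLatencies(module, opLatency);
serializeSelfLatencies(module, opSelfLatency);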

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 14 additions & 0 deletions
@@ -125,6 +125,20 @@ class CoarseSchedule {

   auto find(Operation *op) const { return opToStageAndCluster.find(op); }

+  // Split the cluster containing op into two clusters, one containing all
+  // operations before the op and one containing op and all operations after
+  // the op. Return the cluster containing op and all operations after the op.
+  Cluster splitClusterBefore(Operation *op, scf::ForOp forOp);
+
+  // Check if op a will show up before op b in the final unrolled code.
+  bool isOpBefore(Operation *a, Operation *b);
+
+  // Check if op a is in an earlier cluster than op b.
+  bool isOpInEarlierCluster(Operation *a, Operation *b);
+
+  // Check if op a is in the same cluster as op b.
+  bool isOpInSameCluster(Operation *a, Operation *b);
+
   SmallVector<std::tuple<Operation *, int, Cluster>>
   getOpsInOrder(scf::ForOp forOp);
   std::vector<std::pair<Operation *, unsigned>>
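The new CoarseSchedule queries let pipeliner passes reason about the final emission order of two ops without re-deriving stages and clusters by hand. A small sketch, assuming both ops are already registered in the schedule (namespace qualification omitted):

// Sketch: if `a` would not be emitted before `b`, start a new cluster at `b`
// so that later code can be scheduled immediately before it.
static void ensureOrderedBefore(CoarseSchedule &schedule, scf::ForOp forOp,
                                Operation *a, Operation *b) {
  if (schedule.isOpBefore(a, b))
    return; // already emitted in the desired order
  // `tail` begins at `b` and contains every op after it in b's old cluster.
  CoarseSchedule::Cluster tail = schedule.splitClusterBefore(b, forOp);
  (void)tail; // a real caller would move or insert ops into `tail`
}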

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 0 additions & 1 deletion
@@ -44,7 +44,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "STORE_TMEM_TO_GLOBAL_BYPASS_SMEM",
     "ALLOW_LHS_TMEM_LAYOUT_CONVERSION",
     "TRITON_F32_DEFAULT",
-    "ENABLE_MMA_V5_ATT_PIPELINE",
     "TRITON_INTEL_ADVANCED_PATH",
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 3 additions & 2 deletions
@@ -217,8 +217,9 @@ getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) {
       SharedMemorySpaceAttr::get(argType.getContext());
   auto CTALayout = getCTALayout(argType.getEncoding());
   // No swizzling for scale for now
-  auto newLayout = SwizzledSharedEncodingAttr::get(argType.getContext(), 1, 1,
-                                                   1, newOrder, CTALayout);
+  auto newLayout = NVMMASharedEncodingAttr::get(
+      argType.getContext(), argType.getShape(), newOrder, CTALayout,
+      argType.getElementType(), false);
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 3 additions & 7 deletions
@@ -316,14 +316,10 @@ class UseShmemForScales
       return false;

     auto sharedEnc =
-        cast<triton::gpu::SwizzledSharedEncodingAttr>(scaleType.getEncoding());
-    if (sharedEnc.getMaxPhase() != 1 || sharedEnc.getPerPhase() != 1 ||
-        sharedEnc.getVec() != 1) {
-      // For now, we do not expect swizzling to be applied to the scale SMEM.
-      // This is currently true for non-matmul operand SMEM allocated during
-      // pipelining.
+        dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(scaleType.getEncoding());
+    if (!sharedEnc || sharedEnc.getTransposed() || sharedEnc.getFp4Padded() ||
+        sharedEnc.getSwizzlingByteWidth() != 0)
       return false;
-    }

     if (usesTMAload) {
       return true;
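With the scale SMEM now carrying an NVMMASharedEncodingAttr, the requirement becomes a plain NVMMA layout: not transposed, not fp4-padded, and swizzling byte width zero. The same predicate, pulled out as a standalone helper for illustration (a sketch, not code from this commit):

// True if the scale's shared-memory encoding is an NVMMA layout with no
// transpose, no fp4 padding, and no swizzling, i.e. the check above.
static bool isPlainNVMMASharedEncoding(triton::gpu::MemDescType scaleType) {
  auto enc = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(
      scaleType.getEncoding());
  return enc && !enc.getTransposed() && !enc.getFp4Padded() &&
         enc.getSwizzlingByteWidth() == 0;
}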

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 35 additions & 16 deletions
@@ -260,37 +260,56 @@ class AssignMMALatencies {
       : forOp(forOp), opLatency(opLatency) {};

   void run() {
-    if (!triton::tools::getBoolEnv("ENABLE_MMA_V5_ATT_PIPELINE")) {
-      int mmav5Count = 0;
-      for (auto &op : forOp.getBody()->without_terminator()) {
-        if (isa<ttng::MMAv5OpInterface>(&op)) {
-          mmav5Count++;
-        }
-      }
-      if (mmav5Count > 1)
-        return;
-    }
+    DenseMap<Operation *, int> mmaSelfLatency;
     // Check if the load op (mma operand) is pipelineable.
-    auto isLoadPipelineable = [&](Operation *op) {
+    auto isLoadToBePipelined = [&](Operation *op) {
      return opLatency.count(op) && opLatency[op] > 0;
     };
     for (auto &op : forOp.getBody()->without_terminator()) {
       // If the acc can not be multibuffered, do not pipeline the uses of
       // the MMA to later stages.
       if (auto mma = dyn_cast<ttng::MMAv5OpInterface>(&op)) {
-        if (ttng::mmaHasPipelineableOperands(mma, forOp, isLoadPipelineable) &&
-            !ttng::hasAccReadModifyWrite(mma, forOp) &&
-            ttng::isAccMultibufferingPossible(mma, forOp) &&
-            !getDisallowAccMultiBuffer(forOp)) {
-          opLatency[&op] = 1;
+        // Try to push out the wait by one stage even if the operands are not
+        // pipelineable, but we know where the loads are scheduled, so we can
+        // place the wait right before the loads.
+
+        if (hasSyncDots(forOp)) {
+          // Skip pipelining MMAs in loops that also contain sync dots. This is
+          // a dirty heuristic for performance drops in kernels where we would
+          // rather have the last iteration peeled than run a full iteration of
+          // masked operations only to execute a single wait.
+          continue;
+        }
+        auto pipeHelper = ttng::MMAv5PipelineableOperandsHelper(
+            mma, forOp, isLoadToBePipelined);
+        if (pipeHelper.isPipelineable ||
+            (pipeHelper.isOperandsStateDetermined &&
+             !ttng::hasLoadsAfterMMA(mma, forOp))) {
+          // The MMA can be overlapped with itself.
+          mmaSelfLatency[mma] = 1;
+          if (!ttng::requiresAccMultiBuffering(mma, forOp) ||
+              (ttng::isAccMultibufferingPossible(mma, forOp) &&
+               !getDisallowAccMultiBuffer(forOp))) {
+            // The MMA's users can be pushed to the next stage.
+            opLatency[&op] = 1;
+          }
         }
       }
     }
+    serializeSelfLatencies(forOp->getParentOfType<ModuleOp>(), mmaSelfLatency);
   }

 private:
   scf::ForOp forOp;
   DenseMap<Operation *, int> &opLatency;
+
+  bool hasSyncDots(scf::ForOp forOp) {
+    for (auto &op : forOp.getBody()->without_terminator()) {
+      if (isa<mlir::triton::DotOp>(op))
+        return true;
+    }
+    return false;
+  }
 };

 } // namespace
