Skip to content

Commit b665d5a

Browse files
Merge OpenAI Triton commit c6ee626 (#4188)
This PR changes the Triton base from 3c8cdbf to c6ee626 (May 8). Pass rate: 97.77%
2 parents 108e8a1 + a35e8b3 commit b665d5a

File tree

22 files changed

+1328
-255
lines changed

22 files changed

+1328
-255
lines changed

bench/triton_bench/swiglu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def forward(ctx, a, alpha, precision_config, routing_data):
6969
n_tokens,
7070
BLOCK_M=BLOCK_M,
7171
BLOCK_N=BLOCK_N,
72-
EVEN_N=(N // 2) % 2 == 0,
72+
EVEN_N=(N // 2) % BLOCK_N == 0,
7373
M_BLOCKS=M_BLOCKS,
7474
N_BLOCKS=N_BLOCKS,
7575
flexpoint_saturate_inf=flex_ctx.saturate_inf,

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ class WarpSchedule {
6161

6262
// Create a new partition with a stage.
6363
Partition *addPartition(unsigned stage);
64-
// Give each partition a new index and order. The indices must be unique.
65-
void reorderPartitions(ArrayRef<unsigned> order);
6664
// Update the op to partition mapping.
6765
void updatePartitions();
6866

@@ -74,15 +72,25 @@ class WarpSchedule {
7472
Partition *getPartition(unsigned idx);
7573
// Get the partition at the index.
7674
const Partition *getPartition(unsigned idx) const;
75+
// Insert an operation into a partition.
76+
void insert(Partition *partition, Operation *op);
7777
// Return an iterator range over the partitions.
7878
auto getPartitions() { return llvm::make_pointee_range(partitions); }
7979
// Return an iterator range over the partitions.
8080
auto getPartitions() const { return llvm::make_pointee_range(partitions); }
81+
// Get the number of partitions.
82+
unsigned getNumPartitions() const { return partitions.size(); }
8183
// Get the root partition.
8284
Partition *getRootPartition() { return rootPartition.get(); }
8385
// Get the root partition.
8486
const Partition *getRootPartition() const { return rootPartition.get(); }
8587

88+
// Return true if an operation is assigned to a partition.
89+
bool isScheduled(Operation *op) const;
90+
// Schedule an operation to a partition if it is not already scheduled. Return
91+
// true if the operation was scheduled.
92+
bool trySchedule(Partition *partition, Operation *op);
93+
8694
// Deserialize a warp schedule from an `scf.for` op using the attributes
8795
// tagged on operations in its body.
8896
static FailureOr<WarpSchedule> deserialize(scf::ForOp loop);

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ static const char *kWarpSpecializeAttrName = "tt.warp_specialize";
1919
static const char *kLoopStageAttrName = "loop.stage";
2020
static const char *kLoopClusterAttrName = "loop.cluster";
2121
static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
22+
static const char *kAssignedStageAttrName = "ttg.assigned_stage";
2223

2324
//===----------------------------------------------------------------------===//
2425
// Hoisting Utilities

lib/Analysis/Membar.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "triton/Analysis/Membar.h"
22
#include "triton/Analysis/Alias.h"
33
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
4+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
45

56
#include "mlir/Dialect/Func/IR/FuncOps.h"
67
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -216,6 +217,13 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
216217
}
217218
}
218219
}
220+
// If this op may be signalling other threads asynchronously, make sure
221+
// all shared memory transactions are complete beforehand.
222+
if (isa<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
223+
Interval<size_t> allIntervals(0, std::numeric_limits<size_t>::max());
224+
curBlockInfo.syncWriteIntervals[allIntervals].insert(op);
225+
curBlockInfo.syncReadIntervals[allIntervals].insert(op);
226+
}
219227
scratchBufferId = allocation->getBufferId(op);
220228
}
221229

lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ struct TMEMLoadOpPattern : public OpConversionPattern<ttng::TMEMLoadOp> {
4444
RankedTensorType type = getTMEMTensorLayout(
4545
typeConverter, op.getType(), op.getSrc().getType(), lookupNumWarps(op));
4646
rewriter.modifyOpInPlace(op, [&] { op.getResult().setType(type); });
47+
Type resultType = getTypeConverter()->convertType(op.getType());
48+
rewriter.setInsertionPointAfter(op);
49+
auto cvt = rewriter.create<ConvertLayoutOp>(op.getLoc(), resultType,
50+
op.getResult());
51+
rewriter.replaceAllUsesExcept(op.getResult(), cvt, cvt);
4752
return success();
4853
}
4954
};

lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,13 @@ void tt::CoarseSchedule::serialize(scf::ForOp &forOp) {
189189
for (auto [op, stage, cluster] : getOpsInOrder(forOp)) {
190190
setStageCluster(op, stage, *cluster);
191191
}
192+
193+
Builder b(forOp.getContext());
194+
int maxStages = numStages - 1;
195+
if (auto maxStageAttr = tryGetMaxStage(forOp))
196+
maxStages = std::max(maxStages, *maxStageAttr);
192197
forOp->setAttr(mlir::triton::kScheduledMaxStageAttrName,
193-
IntegerAttr::get(IntegerType::get(forOp.getContext(), 32),
194-
numStages - 1));
198+
b.getI32IntegerAttr(maxStages));
195199
}
196200

197201
// Create a CoarseSchedule based on forOp's <stage, cluster>.

lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ bool hasLatenciesAssigned(scf::ForOp forOp,
5959
for (auto &op : forOp.getBody()->without_terminator()) {
6060
if (opLatency.count(&op))
6161
return true;
62+
if (op.getAttr(kAssignedStageAttrName))
63+
return true;
6264
}
6365
return false;
6466
}
@@ -70,12 +72,15 @@ CoarseSchedule scheduleKeyOps(scf::ForOp forOp,
7072
auto terminator = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
7173
// Determine all operations that have a non-zero latency
7274
SmallVector<Operation *> latOps;
75+
SmallVector<Operation *> stagedOps;
7376
for (auto &op : forOp.getBody()->without_terminator()) {
7477
if (opLatency.count(&op))
7578
latOps.push_back(&op);
79+
if (op.getAttr(kAssignedStageAttrName))
80+
stagedOps.push_back(&op);
7681
}
7782
// If there are no latency ops and no ops with an assigned stage, nothing to schedule
78-
if (latOps.empty())
83+
if (latOps.empty() && stagedOps.empty())
7984
return CoarseSchedule(0);
8085

8186
DominanceInfo domInfo(forOp);
@@ -123,6 +128,11 @@ CoarseSchedule scheduleKeyOps(scf::ForOp forOp,
123128
opToStage[op] = maxDistance - dist;
124129
}
125130

131+
for (Operation *op : stagedOps) {
132+
auto stageAttr = op->getAttrOfType<IntegerAttr>(kAssignedStageAttrName);
133+
opToStage[op] = stageAttr.getInt();
134+
}
135+
126136
auto stages = llvm::make_second_range(opToStage);
127137
int maxStage = *llvm::max_element(stages);
128138
CoarseSchedule schedule(maxStage + 1);

lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ static void removeAttributes(ModuleOp moduleOp) {
7373
op->removeAttr(mlir::triton::kLoopStageAttrName);
7474
op->removeAttr(mlir::triton::kLoopClusterAttrName);
7575
op->removeAttr(mlir::triton::kScheduledMaxStageAttrName);
76+
op->removeAttr(mlir::triton::kAssignedStageAttrName);
7677
});
7778
}
7879

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,30 +1402,35 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
14021402
// Save the operand to replace / delete later (avoid iterator invalidation).
14031403
// TODO: can we use an early_inc iterator?
14041404
for (OpOperand &use : oldUse->getUses()) {
1405+
// Propagate through `ttg.warp_specialize`.
1406+
if (auto wsOp = dyn_cast<ttg::WarpSpecializeOp>(use.getOwner())) {
1407+
for (Region *region : wsOp.getPartitionRegions())
1408+
region->getArgument(use.getOperandNumber()).setType(val.getType());
1409+
}
1410+
14051411
// Non-subview/trans ops will be replaced by `val`.
1406-
if (!isa<triton::gpu::MemDescTransOp, triton::gpu::MemDescSubviewOp>(
1407-
use.getOwner())) {
1412+
if (!isa<ttg::MemDescTransOp, ttg::MemDescSubviewOp>(use.getOwner())) {
14081413
operandsToReplace.push_back(&use);
14091414
continue;
14101415
}
1416+
14111417
Operation *user = use.getOwner();
14121418
// `subview(old_op)` is replaced by a new `subview(val)`.
14131419
OpBuilder::InsertionGuard g(builder);
14141420
builder.setInsertionPoint(user);
14151421
Value newVal;
1416-
if (auto subview = dyn_cast<triton::gpu::MemDescSubviewOp>(user)) {
1417-
triton::gpu::MemDescType oldType = subview.getType();
1418-
bool isMutable =
1419-
cast<triton::gpu::MemDescType>(val.getType()).getMutableMemory();
1420-
Type newDstType = triton::gpu::MemDescType::get(
1422+
if (auto subview = dyn_cast<ttg::MemDescSubviewOp>(user)) {
1423+
ttg::MemDescType oldType = subview.getType();
1424+
bool isMutable = cast<ttg::MemDescType>(val.getType()).getMutableMemory();
1425+
Type newDstType = ttg::MemDescType::get(
14211426
oldType.getShape(), oldType.getElementType(), oldType.getEncoding(),
14221427
oldType.getMemorySpace(), isMutable);
1423-
newVal = builder.create<triton::gpu::MemDescSubviewOp>(
1428+
newVal = builder.create<ttg::MemDescSubviewOp>(
14241429
subview.getLoc(), newDstType, val, subview.getOffsets());
14251430
newVal.getDefiningOp()->setAttrs(user->getAttrs());
1426-
} else if (auto trans = dyn_cast<triton::gpu::MemDescTransOp>(user)) {
1427-
newVal = builder.create<triton::gpu::MemDescTransOp>(trans.getLoc(), val,
1428-
trans.getOrder());
1431+
} else if (auto trans = dyn_cast<ttg::MemDescTransOp>(user)) {
1432+
newVal = builder.create<ttg::MemDescTransOp>(trans.getLoc(), val,
1433+
trans.getOrder());
14291434
newVal.getDefiningOp()->setAttrs(user->getAttrs());
14301435
}
14311436
assert(newVal);

0 commit comments

Comments
 (0)