
Commit 3752705

Merge OpenAI commit ddada27 (#5069)
This PR changes the Triton base from 72ec661 to ddada27 (Sep 4). Pass rate: 98.6%
2 parents 79841bd + c61601d commit 3752705


21 files changed: +284 -382 lines changed


lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 5 additions & 0 deletions
@@ -252,6 +252,11 @@ class CombineDotAddPattern : public mlir::OpRewritePattern<OpTy> {
     }
     if (!isZero(dotOp.getC()))
       return failure();
+    if constexpr (std::is_same_v<OpTy, arith::AddFOp>) {
+      if (dotOp.getMaxNumImpreciseAcc() != 0) {
+        return failure();
+      }
+    }
     rewriter.modifyOpInPlace(dotOp, [&] {
       dotOp.getCMutable().assign(isDotLHS ? addOp.getRhs() : addOp.getLhs());
       dotOp->moveBefore(addOp);
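For context, CombineDotAddPattern folds an elementwise add into the dot accumulator (dot(a, b) + c -> dot(a, b, c)); the new guard skips that fold when imprecise accumulation is requested. Below is a minimal user-level sketch of the shape the pattern targets, assuming tl.dot's max_num_imprecise_acc keyword as in current Triton; the kernel name, tensors, and sizes are illustrative only, and the snippet needs a GPU to run.

import torch
import triton
import triton.language as tl

@triton.jit
def dot_then_add(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK: tl.constexpr):
    rows = tl.arange(0, BLOCK)[:, None]
    cols = tl.arange(0, BLOCK)[None, :]
    a = tl.load(a_ptr + rows * BLOCK + cols)
    b = tl.load(b_ptr + rows * BLOCK + cols)
    c = tl.load(c_ptr + rows * BLOCK + cols)
    # With exact accumulation, this add can be folded into the dot accumulator.
    out = tl.dot(a, b) + c
    # With a nonzero max_num_imprecise_acc (typically used with fp8 inputs),
    # the pattern now bails out, since folding would change which additions
    # run at reduced precision.
    tl.store(out_ptr + rows * BLOCK + cols, out)

a = torch.randn(64, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 64, device="cuda", dtype=torch.float16)
c = torch.randn(64, 64, device="cuda", dtype=torch.float32)
out = torch.empty_like(c)
dot_then_add[(1,)](a, b, c, out, BLOCK=64)
torch.testing.assert_close(out, a.float() @ b.float() + c, rtol=1e-2, atol=1e-2)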

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 3 additions & 2 deletions
@@ -692,18 +692,19 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   int mIndex = 0 + hasBatchDim;
 
   int32_t kWidth = dotMfmaLayout.getKWidth();
-  auto kDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  auto nonKDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
 
   auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
   auto tilesPerWarp = mfmaLayout.getTilesPerWarp();
-  auto tilePerWarpNonK = tilesPerWarp[kDimIndex];
+  auto tilePerWarpNonK = tilesPerWarp[nonKDimIndex];
 
   auto mDim = mfmaLayout.getMDim();
   auto nDim = mfmaLayout.getNDim();
   auto opIdx = dotMfmaLayout.getOpIdx();
   auto nonKDim = opIdx == 0 ? mDim : nDim;
   constexpr int warpSize = 64;
 
+  auto kDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
   int32_t kSize = shape[kDimIndex];
 
   MLIRContext *ctx = dotMfmaLayout.getContext();
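The fix above indexes tilesPerWarp with the non-K dimension rather than the K dimension. A plain-Python sanity sketch of that index selection (dim_indices is an illustrative helper, not a function in the Triton code base):

# Operand A (opIdx == 0) is shaped (..., M, K); operand B (opIdx == 1) is (..., K, N).
def dim_indices(op_idx: int, rank: int):
    k_dim_index = rank - 1 if op_idx == 0 else rank - 2
    non_k_dim_index = rank - 2 if op_idx == 0 else rank - 1
    return k_dim_index, non_k_dim_index

# Rank-2 case: A's non-K dim is M (index 0); B's non-K dim is N (index 1).
assert dim_indices(0, 2) == (1, 0)
assert dim_indices(1, 2) == (0, 1)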

lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp

Lines changed: 3 additions & 0 deletions
@@ -1006,6 +1006,9 @@ static LogicalResult speculateInnerLoopLength(scf::ForOp outerLoop,
   newInnerLoop.replaceAllUsesWith(newInnerLoop.getInits());
   newInnerLoop.erase();
 
+  // Clear up the warp specialization attributes for the specialized loop.
+  newLoop->removeAttr(kWarpSpecializeAttrName);
+
   // Move the loop nest into the `else` branch.
   outerLoop.replaceAllUsesWith(ifOp.getResults());
   Block *block = b.createBlock(&ifOp.getElseRegion());

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp

Lines changed: 32 additions & 13 deletions
@@ -1,7 +1,10 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Partition.h"
 #include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h"
@@ -202,17 +205,12 @@ LogicalResult DependencyRewriter::run() {
        llvm::zip(schedule.getPartitions(), partitionUseInfo)) {
     // The amount of buffering is based on the longest distance to a user.
     for (auto &[output, info] : useInfo) {
-      // FIXME: No IR support for passing simple scalars through shared
-      // memory.
-      auto tensorType = dyn_cast<RankedTensorType>(output.getType());
-      if (!tensorType) {
-        return mlir::emitWarning(output.getLoc(),
-                                 "FIXME: only tensor SSA dependencies between "
-                                 "partitions are supported");
-      }
+      b.setLoc(output.getLoc());
+      ImplicitLocOpBuilder endBuilder(b.getLoc(), loop->getNextNode());
 
-      Operation *defOp;
+      bool isScalar = false;
       Value tmp = output;
+      Operation *defOp;
       while (true) {
         if (auto arg = dyn_cast<BlockArgument>(tmp)) {
           tmp = loop.getBody()->getTerminator()->getOperand(arg.getArgNumber() -
@@ -222,14 +220,31 @@ LogicalResult DependencyRewriter::run() {
         defOp = tmp.getDefiningOp();
         break;
       }
+      Value val = output;
+      auto tensorType = dyn_cast<RankedTensorType>(output.getType());
+      if (!tensorType) {
+        isScalar = true;
+        b.setInsertionPointAfterValue(output);
+        auto mod = output.getParentRegion()->getParentOfType<ModuleOp>();
+        auto nWarps = lookupNumWarps(mod);
+        auto threadsPerWarp =
+            triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+        int CTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
+        Attribute encoding = getDefaultBlockedEncoding(
+            b.getContext(), {1}, nWarps, threadsPerWarp, CTAs);
+        tensorType = RankedTensorType::get({1}, output.getType(), encoding);
+        StageCluster srcStageCluster = getStageCluster(defOp);
+
+        defOp = b.createInto<triton::SplatOp>(partition, srcStageCluster,
+                                              tensorType, output);
+        val = defOp->getResult(0);
+      }
 
       // Buffer the value based on the greatest distance to a consumer
       // partition.
       int maxDistance = info.getMaxUseDistance(partition);
 
       // Allocate buffers for the value and its associated barriers.
-      b.setLoc(output.getLoc());
-      ImplicitLocOpBuilder endBuilder(b.getLoc(), loop->getNextNode());
       AsyncRef aref = allocateAsyncValue(tensorType, maxDistance);
 
       unsigned numConsumers = info.consumers.size();
@@ -249,20 +264,24 @@ LogicalResult DependencyRewriter::run() {
         // partition with it.
         Value value = b.createInto<LocalLoadOp>(*usePartition, sinkSrcCluster,
                                                 tensorType, view);
+        if (isScalar) {
+          value = b.createInto<triton::UnsplatOp>(*usePartition, sinkSrcCluster,
+                                                  value);
+        }
         for (OpOperand *use : uses)
           use->set(value);
         exitOp(b);
       }
 
       // Set up production of the value
-      if (isa<BlockArgument>(output))
+      if (isa<BlockArgument>(val))
         b.setInsertionPointToStart(loop.getBody());
       else
         b.setInsertionPointAfter(defOp);
 
       StageCluster srcStageCluster = getStageCluster(defOp);
       auto [view, exitOp] = aref.putView(b, partition, srcStageCluster);
-      b.createInto<LocalStoreOp>(partition, srcStageCluster, output, view);
+      b.createInto<LocalStoreOp>(partition, srcStageCluster, val, view);
       exitOp(b);
     }
   }

python/src/ir.cc

Lines changed: 3 additions & 6 deletions
@@ -554,6 +554,8 @@ void init_triton_ir(py::module &&m) {
           })
       .def("verify",
            [](OpState &self) -> bool {
+             TritonSourceMgrDiagnosticHandler handler =
+                 setupTritonDiagnosticHandler(self.getContext());
              return succeeded(verify(self.getOperation()));
            })
       .def("get_operation", [](OpState &self) { return self.getOperation(); });
@@ -700,12 +702,7 @@ void init_triton_ir(py::module &&m) {
       .def("walk",
            [](ModuleOp &self, const std::function<void(Operation *)> &fn) {
              self.walk(fn);
-           })
-      .def("verify_with_diagnostics", [](ModuleOp &self) {
-        TritonSourceMgrDiagnosticHandler handler =
-            setupTritonDiagnosticHandler(self.getContext());
-        return succeeded(verify(self.getOperation()));
-      });
+           });
 
   m.def("make_attr", [](const std::vector<int> &values, MLIRContext &context) {
     return mlir::cast<Attribute>(DenseIntElementsAttr::get(
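With this change, verify() installs the Triton source-manager diagnostic handler itself, which is why the separate verify_with_diagnostics() binding on ModuleOp is removed. A minimal usage sketch, assuming the in-tree bindings expose context(), load_dialects(), parse_mlir_module(path, ctx), and module.verify() as they do today; the temporary file and its contents are purely illustrative.

import os
import tempfile
from triton._C.libtriton import ir  # assumption: Triton's private MLIR bindings

# Write a trivially valid MLIR module to a throwaway file.
with tempfile.NamedTemporaryFile("w", suffix=".mlir", delete=False) as f:
    f.write("module {}\n")
    path = f.name

ctx = ir.context()
ir.load_dialects(ctx)
mod = ir.parse_mlir_module(path, ctx)
# verify() now reports failures through the Triton diagnostic handler with
# source locations; there is no separate verify_with_diagnostics() entry point.
assert mod.verify()
os.remove(path)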

python/test/gluon/test_lowerings.py

Lines changed: 45 additions & 0 deletions
@@ -1209,3 +1209,48 @@ def test_gather_layouts(axis, src_layout, index_layout, src_shape, idx_shape, de
 
     torch.testing.assert_close(out, ref, rtol=0, atol=0)
     assert ("nvvm.shfl.sync.idx" in obj.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in obj.asm["llir"])
+
+
+@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size",
+                         [[128, 128, 64, 64], [128, 128, 64, 32], [128, 64, 64, 32], [256, 128, 64, 64]])
+def test_memdesc_subslice(M, N, M_tile_size, N_tile_size, device):
+    if M % M_tile_size != 0 or N % N_tile_size != 0:
+        pytest.skip(f"Shape size ({M}, {N}) must be divisible by tile size ({M_tile_size}, {N_tile_size})")
+
+    num_rows_per_warp = THREADS_PER_WARP // 4
+    blocked_layout = ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[num_rows_per_warp, 4],
+                                        warps_per_cta=[4, 1], order=[1, 0])
+    shared_layout = ttgl.SwizzledSharedLayout(vec=8, per_phase=1, max_phase=8, order=[1, 0])
+
+    @gluon.jit
+    def kernel(
+        out,
+        M: ttgl.constexpr,
+        N: ttgl.constexpr,
+        BLOCK_SIZE_M: ttgl.constexpr,
+        BLOCK_SIZE_N: ttgl.constexpr,
+        blocked_layout: ttgl.constexpr,
+        shared_layout: ttgl.constexpr,
+    ):
+        offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, blocked_layout))[:, None]
+        offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, blocked_layout))[None, :]
+        vals = ttgl.load(out + offs_m * N + offs_n)
+
+        smem: ttgl.shared_memory_descriptor = ttgl.allocate_shared_memory(vals.dtype, (M, N), shared_layout, value=vals)
+        for i in ttgl.static_range(M // BLOCK_SIZE_M):
+            for j in ttgl.static_range(N // BLOCK_SIZE_N):
+                tile = smem.slice(i * BLOCK_SIZE_M, BLOCK_SIZE_M, dim=0).slice(j * BLOCK_SIZE_N, BLOCK_SIZE_N, dim=1)
+                tile_vals = tile.load(blocked_layout)
+                tile_offs_m = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked_layout))[:, None]
+                tile_offs_n = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked_layout))[None, :]
+                linear_idx = tile_offs_m * N + tile_offs_n + i * BLOCK_SIZE_M * N + j * BLOCK_SIZE_N
+                tile.store(linear_idx + tile_vals)
+
+        vals = smem.load(blocked_layout)
+        ttgl.store(out + offs_m * N + offs_n, vals)
+
+    out = torch.zeros((M, N), device=device, dtype=torch.float16)
+    kernel[(1, )](out, M, N, M_tile_size, N_tile_size, blocked_layout, shared_layout)
+
+    out_ref = torch.arange(0, M * N, device=device).reshape((M, N)).to(torch.float16)
+    torch.testing.assert_close(out, out_ref, rtol=0, atol=0)
