Commit 5c5ab9f

[WS] Support scalar ops across partition (triton-lang#8061)
We use splat/unsplat as we currently don't support allocation of scalars.
1 parent 0ef5eae commit 5c5ab9f
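
For orientation, here is roughly what the rewrite produces for a scalar defined in partition 0 and used in partition 1. This is a hand-simplified sketch distilled from the new @scalar_consumers test below: the nvws.aref put/get enter and exit bookkeeping, tokens, and memdesc types are elided, and %scalar, %buf, %t, %l, %s are placeholder names.

  // Producer side (partition 0): splat the scalar into a 1-element tensor and
  // store it into the shared-memory buffer obtained from the aref.
  %t = tt.splat %scalar {ttg.partition = 0 : i32} : i32 -> tensor<1xi32, #blocked>
  ttg.local_store %t, %buf {ttg.partition = 0 : i32}

  // Consumer side (partition 1): load the 1-element tensor back and unsplat it
  // to recover the scalar before the original use.
  %l = ttg.local_load %buf {ttg.partition = 1 : i32}
  %s = tt.unsplat %l {ttg.partition = 1 : i32} : tensor<1xi32, #blocked>
  "op_b"(%s) {ttg.partition = 1} : (i32) -> ()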

File tree: 2 files changed, +57 −13 lines

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp

Lines changed: 32 additions & 13 deletions
@@ -1,7 +1,10 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Partition.h"
 #include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h"
@@ -202,17 +205,12 @@ LogicalResult DependencyRewriter::run() {
        llvm::zip(schedule.getPartitions(), partitionUseInfo)) {
     // The amount of buffering is based on the longest distance to a user.
     for (auto &[output, info] : useInfo) {
-      // FIXME: No IR support for passing simple scalars through shared
-      // memory.
-      auto tensorType = dyn_cast<RankedTensorType>(output.getType());
-      if (!tensorType) {
-        return mlir::emitWarning(output.getLoc(),
-                                 "FIXME: only tensor SSA dependencies between "
-                                 "partitions are supported");
-      }
+      b.setLoc(output.getLoc());
+      ImplicitLocOpBuilder endBuilder(b.getLoc(), loop->getNextNode());

-      Operation *defOp;
+      bool isScalar = false;
       Value tmp = output;
+      Operation *defOp;
       while (true) {
         if (auto arg = dyn_cast<BlockArgument>(tmp)) {
           tmp = loop.getBody()->getTerminator()->getOperand(arg.getArgNumber() -
@@ -222,14 +220,31 @@ LogicalResult DependencyRewriter::run() {
         defOp = tmp.getDefiningOp();
         break;
       }
+      Value val = output;
+      auto tensorType = dyn_cast<RankedTensorType>(output.getType());
+      if (!tensorType) {
+        isScalar = true;
+        b.setInsertionPointAfterValue(output);
+        auto mod = output.getParentRegion()->getParentOfType<ModuleOp>();
+        auto nWarps = lookupNumWarps(mod);
+        auto threadsPerWarp =
+            triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+        int CTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
+        Attribute encoding = getDefaultBlockedEncoding(
+            b.getContext(), {1}, nWarps, threadsPerWarp, CTAs);
+        tensorType = RankedTensorType::get({1}, output.getType(), encoding);
+        StageCluster srcStageCluster = getStageCluster(defOp);
+
+        defOp = b.createInto<triton::SplatOp>(partition, srcStageCluster,
+                                              tensorType, output);
+        val = defOp->getResult(0);
+      }

       // Buffer the value based on the greatest distance to a consumer
       // partition.
       int maxDistance = info.getMaxUseDistance(partition);

       // Allocate buffers for the value and its associated barriers.
-      b.setLoc(output.getLoc());
-      ImplicitLocOpBuilder endBuilder(b.getLoc(), loop->getNextNode());
       AsyncRef aref = allocateAsyncValue(tensorType, maxDistance);

       unsigned numConsumers = info.consumers.size();
@@ -249,20 +264,24 @@ LogicalResult DependencyRewriter::run() {
         // partition with it.
         Value value = b.createInto<LocalLoadOp>(*usePartition, sinkSrcCluster,
                                                 tensorType, view);
+        if (isScalar) {
+          value = b.createInto<triton::UnsplatOp>(*usePartition, sinkSrcCluster,
+                                                  value);
+        }
         for (OpOperand *use : uses)
           use->set(value);
         exitOp(b);
       }

       // Set up production of the value
-      if (isa<BlockArgument>(output))
+      if (isa<BlockArgument>(val))
         b.setInsertionPointToStart(loop.getBody());
       else
         b.setInsertionPointAfter(defOp);

       StageCluster srcStageCluster = getStageCluster(defOp);
       auto [view, exitOp] = aref.putView(b, partition, srcStageCluster);
-      b.createInto<LocalStoreOp>(partition, srcStageCluster, output, view);
+      b.createInto<LocalStoreOp>(partition, srcStageCluster, val, view);
       exitOp(b);
     }
   }
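
The core of the new scalar path in the hunk above is the construction of a tensor<1 x element-type> carrier type with the module's default blocked encoding, followed by a tt.splat into it. As a standalone sketch (not a utility that exists in the codebase), reusing the same helpers the pass already calls and assuming the usual Triton/MLIR headers, the type construction reads roughly as:

// Sketch only: derive a 1-element tensor type that can carry a scalar `output`
// through the existing tensor-based aref buffering. Mirrors the diff above;
// the function name is illustrative, not part of the codebase.
static RankedTensorType getScalarCarrierType(Value output, OpBuilder &b) {
  auto mod = output.getParentRegion()->getParentOfType<ModuleOp>();
  int numWarps = lookupNumWarps(mod);
  int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
  int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
  Attribute encoding = getDefaultBlockedEncoding(
      b.getContext(), /*shape=*/{1}, numWarps, threadsPerWarp, numCTAs);
  // tensor<1 x scalar-type> with the default blocked layout.
  return RankedTensorType::get({1}, output.getType(), encoding);
}

The consumer side mirrors this with a ttg.local_load followed by tt.unsplat, which is why the commit message frames the approach as a splat/unsplat round trip rather than true scalar allocation.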

test/TritonGPU/rewrite-partition-dependencies.mlir

Lines changed: 25 additions & 0 deletions
@@ -320,6 +320,31 @@ tt.func @no_def_op(%lb: i32, %ub: i32, %step: i32) {
   tt.return
 }

+// CHECK-LABEL: @scalar_consumers
+tt.func @scalar_consumers(%lb: i32, %ub: i32, %step: i32) {
+  // CHECK: [[C0:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[ABUF:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x1xi32, {{.*}}>
+  // CHECK-NEXT: [[AREF:%.*]] = nvws.aref.create [[ABUF]]
+  scf.for %i = %lb to %ub step %step iter_args() -> () : i32 {
+    %0 = "op_a"() {ttg.partition = 0} : () -> i32
+    // CHECK: [[VAL:%.*]] = "op_a"
+    // CHECK-NEXT: [[VAL_TENSOR:%.*]] = tt.splat [[VAL]] {ttg.partition = 0 : i32} : i32 -> tensor<1xi32, #blocked>
+    // CHECK-NEXT: [[BUF:%.*]], [[TOKEN:%.*]] = nvws.aref.put.enter [[AREF]][[[C0]], [[C0]]] {ttg.partition = 0 : i32}
+    // CHECK-NEXT: ttg.local_store [[VAL_TENSOR]], [[BUF]] {ttg.partition = 0 : i32}
+    // CHECK-NEXT: nvws.aref.put.exit [[AREF]][[[C0]]], [[TOKEN]] [#nvws.async_op<none>] {ttg.partition = 0 : i32}
+
+    "op_b"(%0) {ttg.partition = 1} : (i32) -> ()
+    // CHECK-NEXT: [[BUF:%.*]], [[TOKEN:%.*]] = nvws.aref.get.enter [[AREF]][[[C0]], [[C0]]] {ttg.partition = 1 : i32}
+    // CHECK-NEXT: [[VAL:%.*]] = ttg.local_load [[BUF]] {ttg.partition = 1 : i32}
+    // CHECK-NEXT: [[VAL_SCALAR:%.*]] = tt.unsplat [[VAL]] {ttg.partition = 1 : i32} : tensor<1xi32, #blocked>
+    // CHECK-NEXT: nvws.aref.get.exit [[AREF]][[[C0]]], [[TOKEN]] [#nvws.async_op<none>] {ttg.partition = 1 : i32}
+    // CHECK-NEXT: "op_b"([[VAL_SCALAR]])
+
+  } {ttg.partition.stages = [0, 2], ttg.warp_specialize.tag = 0 : i32}
+  tt.return
+}
+
+
 }

 // -----
