Commit 4a8277b

[PIPELINER] Support pipelining scalar loads (#7498)
Until now, scalar loads were not pipelined, even though they can still cause latency problems. Extend the pipeliner to support scalar loads; to keep things simple, we convert them into tensor<1> loads during loop lowering. This also introduces a new `unsplat` op to make the conversion from tensor back to scalar straightforward.
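
For illustration, a minimal sketch of the rewrite performed during loop lowering (the SSA names and the #blocked layout are illustrative, not taken from the commit): the scalar pointer is splatted to a single-element tensor, the load is done on that tensor, and the new `unsplat` op recovers the scalar for the existing users.

// before: a pipelined scalar load inside the loop
%x = tt.load %ptr : !tt.ptr<i32>

// after: tensor<1> load plus unsplat (layout shown inline for brevity)
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
%ptr_t = tt.splat %ptr : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>, #blocked>
%x_t = tt.load %ptr_t : tensor<1x!tt.ptr<i32>, #blocked>
%x = tt.unsplat %x_t : tensor<1xi32, #blocked>

In this form the load can go through the pipeliner's usual shared-memory buffering and async-copy path, while downstream users keep operating on the original scalar.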
1 parent be6ef7a commit 4a8277b

9 files changed (+170 -67 lines)
include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 10 additions & 0 deletions
@@ -432,6 +432,16 @@ def TT_SplatOp : TT_Op<"splat", [Pure,
   let hasFolder = 1;
 }
 
+def TT_UnsplatOp : TT_Op<"unsplat", [Pure,
+                         DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "convert a tensor with a single element to a scalar";
+  let arguments = (ins TT_Tensor:$src);
+  let results = (outs TT_Type:$result);
+
+  let assemblyFormat = "$src attr-dict `:` type($src)";
+  let hasVerifier = 1;
+}
+
 def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure,
                             DeclareOpInterfaceMethods<InferTypeOpInterface>,
                             SameOperandsAndResultElementType]> {
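
As a usage sketch (value names illustrative), the op accepts any tensor with exactly one element and yields its scalar value; the result type is inferred from the source element type:

%s = tt.unsplat %t : tensor<1xf32>      // %s : f32
%r = tt.unsplat %u : tensor<1x1xi32>    // %r : i32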

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 13 additions & 0 deletions
@@ -60,6 +60,18 @@ struct SplatOpConversion : public ConvertOpToLLVMPattern<triton::SplatOp> {
     return success();
   }
 };
+
+struct UnsplatOpConversion : public ConvertOpToLLVMPattern<triton::UnsplatOp> {
+  using ConvertOpToLLVMPattern<triton::UnsplatOp>::ConvertOpToLLVMPattern;
+  LogicalResult matchAndRewrite(triton::UnsplatOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+    auto loc = op->getLoc();
+    auto scrVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+    rewriter.replaceOp(op, scrVals[0]);
+    return success();
+  }
+};
+
 // This pattern helps to convert arith::ConstantOp(with SplatElementsAttr),
 // the logic is the same as triton::SplatOp, so the underlying implementation
 // is reused.
@@ -550,6 +562,7 @@ void mlir::triton::populateViewOpToLLVMPatterns(
   patterns.add<ReshapeOpConversion>(typeConverter, benefit);
   patterns.add<ExpandDimsOpConversion>(typeConverter, benefit);
   patterns.add<SplatOpConversion>(typeConverter, benefit);
+  patterns.add<UnsplatOpConversion>(typeConverter, benefit);
   patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
   patterns.add<ArithConstantArrayOpConversion>(typeConverter, benefit);
   patterns.add<CatOpConversion>(typeConverter, benefit);

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 18 additions & 0 deletions
@@ -606,6 +606,24 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   return ret;
 }
 
+//-- UnsplatOp --
+LogicalResult UnsplatOp::verify() {
+  auto srcShape = getSrc().getType().getShape();
+  if (product(srcShape) != 1) {
+    return emitError("source tensor must have exactly one element");
+  }
+  return success();
+}
+
+LogicalResult UnsplatOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto dstTy = cast<RankedTensorType>(operands[0].getType()).getElementType();
+  inferredReturnTypes.push_back(dstTy);
+  return success();
+}
+
 //-- ExpandDimsOp --
 LogicalResult ExpandDimsOp::inferReturnTypes(
     MLIRContext *context, std::optional<Location> loc, ValueRange operands,

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 74 additions & 12 deletions
@@ -334,6 +334,49 @@ struct LoadGroupInfo {
   bool hasTMALoad = false;
 };
 
+// Convert a scalar load to a load of a tensor of shape <1>.
+void convertScalarToTensorLoad(Operation *op, CoarseSchedule &schedule,
+                               scf::ForOp forOp) {
+  auto scalarLoad = cast<tt::LoadOp>(op);
+  Type scalarTy = scalarLoad.getType();
+  OpBuilderForStage builder(op->getLoc(), op, schedule);
+  builder.setInsertionPoint(op);
+  MLIRContext *ctx = op->getContext();
+  auto nWarps = lookupNumWarps(op);
+  ModuleOp mod = forOp->getParentOfType<ModuleOp>();
+  auto threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
+  auto numCTAs = TritonGPUDialect::getNumCTAs(mod);
+  auto blockedEnc =
+      getDefaultBlockedEncoding(ctx, {1}, nWarps, threadsPerWarp, numCTAs);
+  auto newPtrTy =
+      RankedTensorType::get({1}, scalarLoad.getPtr().getType(), blockedEnc);
+  auto newPtr =
+      builder.create<tt::SplatOp>(op->getLoc(), newPtrTy, scalarLoad.getPtr());
+  scalarLoad.getPtrMutable().assign(newPtr);
+  if (scalarLoad.getMask()) {
+    auto newMaskTy =
+        RankedTensorType::get({1}, scalarLoad.getMask().getType(), blockedEnc);
+    auto newMask = builder.create<tt::SplatOp>(op->getLoc(), newMaskTy,
+                                               scalarLoad.getMask());
+    scalarLoad.getMaskMutable().assign(newMask);
+  }
+  if (scalarLoad.getOther()) {
+    auto newOtherTy =
+        RankedTensorType::get({1}, scalarLoad.getOther().getType(), blockedEnc);
+    auto newOther = builder.create<tt::SplatOp>(op->getLoc(), newOtherTy,
+                                                scalarLoad.getOther());
+    scalarLoad.getOtherMutable().assign(newOther);
+  }
+  auto newDstTy = RankedTensorType::get({1}, scalarLoad.getType(), blockedEnc);
+  scalarLoad.getResult().setType(newDstTy);
+  builder.setInsertionPointAfter(op);
+  Operation *firstUse = getFirstUseOfPipelinedOp({op}, forOp, schedule);
+  builder.setStageCluster(schedule[firstUse]);
+  Operation *unsplat = builder.create<tt::UnsplatOp>(op->getLoc(), scalarTy,
+                                                     scalarLoad.getResult());
+  scalarLoad.getResult().replaceAllUsesExcept(unsplat->getResult(0), unsplat);
+}
+
 void createTMABarrierAndWait(
     scf::ForOp forOp, llvm::MapVector<Operation *, AsyncLoad> &asyncLoads,
     const llvm::MapVector<int, LoadGroupInfo> &loadGroups,
@@ -446,25 +489,39 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
                       triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
   llvm::MapVector<Operation *, AsyncLoad> asyncLoads;
   llvm::MapVector<int, LoadGroupInfo> loadGroups;
+  llvm::SmallVector<Operation *> scalarLoads;
   // Only visit the top level ops, we do not support pipelining conditional
   // loads for now
   for (auto &op : forOp.getBody()->without_terminator()) {
     if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op)) {
       int stageDiff = getDefUseStageDiff(&op, forOp, schedule);
-      if (stageDiff == 0 || !isa<RankedTensorType>(op.getResultTypes()[0])) {
-        // Don't care about non-pipelined loads. Don't use async loads for
-        // scalar values.
+      if (stageDiff == 0) {
+        // Don't care about non-pipelined loads. Scalar loads will be converted
+        // to tensor loads if they are pipelined.
         continue;
       }
-      SharedEncodingTrait sharedEncoding = getSharedEncoding(&op);
-      // Do not create async loads for small loads (cp.async requires at least 4
-      // bytes)
-      bool canUseAsyncCp =
-          isa<tt::LoadOp>(op) &&
-          canBeConvertedToAsyncLoad(cast<tt::LoadOp>(op), axisInfoAnalysis);
-      int copyVecBytes = getCopyVecBytes(
-          cast<RankedTensorType>(op.getResultTypes()[0]), sharedEncoding);
-      canUseAsyncCp &= copyVecBytes >= 4;
+      SharedEncodingTrait sharedEncoding;
+      bool canUseAsyncCp = false;
+      if (!isa<RankedTensorType>(op.getResultTypes()[0])) {
+        canUseAsyncCp = op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32;
+        sharedEncoding = ttg::SwizzledSharedEncodingAttr::get(
+            forOp.getContext(), 1, 1, 1, {0},
+            ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0}));
+        if (canUseAsyncCp) {
+          scalarLoads.push_back(&op);
+        }
+      } else {
+        sharedEncoding = getSharedEncoding(&op);
+        // Do not create async loads for small loads (cp.async requires at least
+        // 4 bytes)
+        canUseAsyncCp =
+            isa<tt::LoadOp>(op) &&
+            canBeConvertedToAsyncLoad(cast<tt::LoadOp>(op), axisInfoAnalysis);
+        int copyVecBytes = getCopyVecBytes(
+            cast<RankedTensorType>(op.getResultTypes()[0]), sharedEncoding);
+
+        canUseAsyncCp &= copyVecBytes >= 4;
+      }
       if (canUseAsyncCp || isTMALoad(&op)) {
         if (loadRequiresAdditionalBuffer(&op)) {
           // Allocate additional buffer required by the wgmma pipelining.
@@ -486,6 +543,11 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
     }
   }
 
+  // Convert scalar loads to be able to use async copy.
+  for (auto op : scalarLoads) {
+    convertScalarToTensorLoad(op, schedule, forOp);
+  }
+
   if (asyncLoads.empty())
     return forOp;

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 9 additions & 4 deletions
@@ -330,10 +330,15 @@ bool mlir::triton::canBeConvertedToAsyncLoad(
     vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));
 
   auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
-  if (!tensorTy)
-    return false;
-  auto ty = cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
-  unsigned width = vec * ty.getIntOrFloatBitWidth();
+  unsigned width = 0;
+  if (tensorTy) {
+    auto ty = cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
+    width = vec * ty.getIntOrFloatBitWidth();
+  } else {
+    width = cast<tt::PointerType>(ptr.getType())
+                .getPointeeType()
+                .getIntOrFloatBitWidth();
+  }
 
   // We do not pipeline all loads for the following reasons:
   // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16.

test/Triton/invalid.mlir

Lines changed: 8 additions & 0 deletions
@@ -538,3 +538,11 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
   tt.return %result : tensor<128x128xf32, #blocked>
 }
 }
+
+// -----
+
+tt.func @unsplat_invalid(%arg0: tensor<128xf32>) {
+  // expected-error @below {{source tensor must have exactly one element}}
+  %0 = tt.unsplat %arg0 : tensor<128xf32>
+  tt.return
+}

test/Triton/ops.mlir

Lines changed: 7 additions & 0 deletions
@@ -278,3 +278,10 @@ tt.func @tma_scatter(%arg0: !tt.tensordesc<tensor<1x128xbf16>>, %arg1: tensor<32
   tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc<tensor<1x128xbf16>>, tensor<32xi32>, i32, tensor<32x128xbf16>
   tt.return
 }
+
+// CHECK-LABEL: @unsplat
+tt.func @unsplat(%arg0: tensor<1x1xf32>) -> f32 {
+  // CHECK-NEXT: tt.unsplat %{{.+}} : tensor<1x1xf32>
+  %0 = tt.unsplat %arg0 : tensor<1x1xf32>
+  tt.return %0 : f32
+}

test/TritonGPU/loop-pipeline.mlir

Lines changed: 10 additions & 51 deletions
@@ -393,10 +393,12 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
 // CHECK: ttg.async_copy_global_to_local
 // CHECK: ttg.async_commit_group
 // CHECK: scf.for
-// CHECK: ttg.async_wait {{.*}} {num = 2 : i32}
+// CHECK: ttg.async_wait {{.*}} {num = 1 : i32}
 // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}}
 // CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]]
-// CHECK: %[[IND_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}}
+// CHECK: ttg.async_wait {{.*}} {num = 1 : i32}
+// CHECK: %[[IND_BUFFER_0_T:.*]] = ttg.local_load
+// CHECK: %[[IND_BUFFER_0:.*]] = tt.unsplat %[[IND_BUFFER_0_T]] : tensor<1xi64
 // CHECK: %[[IND_BUFFER_1:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_0]]
 // CHECK: %[[IND_BUFFER_2:.*]] = tt.splat %[[IND_BUFFER_1]]
 // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_2]]
@@ -406,9 +408,9 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
 // AMD: %[[LOCAL_ALLOC_0:.*]] = ttg.local_alloc
 // AMD: %[[LOCAL_ALLOC_1:.*]] = ttg.local_alloc
 // AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}}
+// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] {amd.pipeliner_part = "prologue"}
 // AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]]
 // AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] {amd.pipeliner_part = "prologue"}
-// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]]
 // AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]]
 // AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]]
 // AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]]
@@ -418,29 +420,14 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
 // AMD: ttg.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_11]]
 // AMD: %[[MEMDESC_SUBVIEW_12:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}]
 // AMD: ttg.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_12]]
-// AMD: %[[CMPI_13:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}}
-// AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %{{.*}}
-// AMD: %[[ADDPTR_15:.*]] = tt.addptr %{{.*}}, %{{.*}}
-// AMD: %[[SPLAT_16:.*]] = tt.splat %[[CMPI_13]]
-// AMD: %[[LOAD_17:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_16]] {amd.pipeliner_part = "prologue"}
-// AMD: %[[LOAD_18:.*]] = tt.load %[[ADDPTR_15]], %[[CMPI_13]]
-// AMD: %[[MULI_19:.*]] = arith.muli %{{.*}}, %[[LOAD_18]]
-// AMD: %[[SPLAT_20:.*]] = tt.splat %[[MULI_19]]
-// AMD: %[[ADDPTR_21:.*]] = tt.addptr %{{.*}}, %[[SPLAT_20]]
-// AMD: %[[SPLAT_22:.*]] = tt.splat %[[CMPI_13]]
-// AMD: %[[LOAD_23:.*]] = tt.load %[[ADDPTR_21]], %[[SPLAT_22]] {amd.pipeliner_part = "prologue"}
-// AMD: %[[MEMDESC_SUBVIEW_24:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}]
-// AMD: ttg.local_store %[[LOAD_17]], %[[MEMDESC_SUBVIEW_24]]
-// AMD: %[[MEMDESC_SUBVIEW_25:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}]
-// AMD: ttg.local_store %[[LOAD_23]], %[[MEMDESC_SUBVIEW_25]]
 // AMD: %[[SUBI_26:.*]] = arith.subi %{{.*}}, %{{.*}}
-// AMD: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_26]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_14]], %[[ARG9:.*]] = %[[ADDPTR_15]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_11]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_24]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_12]], %[[ARG14:.*]] = %[[MEMDESC_SUBVIEW_25]])
+// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_26]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_11]], %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_12]])
 // AMD: %[[ADDPTR_38:.*]] = tt.addptr %[[ARG8]], %{{.*}}
 // AMD: %[[ADDPTR_39:.*]] = tt.addptr %[[ARG9]], %{{.*}}
 // AMD: %[[LOAD_40:.*]] = tt.load %[[ADDPTR_38]]
 // AMD: %[[LOCAL_LOAD_41:.*]] = ttg.local_load %[[ARG11]]
 // AMD: %[[LOAD_42:.*]] = tt.load %[[ADDPTR_39]]
-// AMD: %[[MULI_43:.*]] = arith.muli %{{.*}}, %[[LOAD_42]]
+// AMD: %[[MULI_43:.*]] = arith.muli %{{.*}}, %[[ARG12]]
 // AMD: %[[SPLAT_44:.*]] = tt.splat %[[MULI_43]]
 // AMD: %[[ADDPTR_45:.*]] = tt.addptr %{{.*}}, %[[SPLAT_44]]
 // AMD: %[[LOAD_46:.*]] = tt.load %[[ADDPTR_45]]
@@ -453,7 +440,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
 // AMD: ttg.local_store %[[LOAD_40]], %[[MEMDESC_SUBVIEW_52]]
 // AMD: %[[MEMDESC_SUBVIEW_53:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_51]], %{{.*}}, %{{.*}}]
 // AMD: ttg.local_store %[[LOAD_46]], %[[MEMDESC_SUBVIEW_53]]
-// AMD: scf.yield %[[DOT_48]], %[[ADDPTR_38]], %[[ADDPTR_39]], %[[SELECT_51]], %[[ARG12]], %[[MEMDESC_SUBVIEW_52]], %[[ARG14]], %[[MEMDESC_SUBVIEW_53]]
+// AMD: scf.yield %[[DOT_48]], %[[ADDPTR_38]], %[[ADDPTR_39]], %[[SELECT_51]], %[[MEMDESC_SUBVIEW_52]], %[[LOAD_42]], %[[MEMDESC_SUBVIEW_53]]
 // AMD: } {tt.num_stages = 3
 // AMD: %[[CMPI_28:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}}
 // AMD: %[[CMPI_29:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}}
@@ -466,8 +453,8 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
 // AMD: scf.yield %{{.*}}#0
 // AMD: }
 // AMD: %[[SELECT_33:.*]] = arith.select %[[CMPI_28]], %[[IF_32]], %{{.*}}#0
-// AMD: %[[LOCAL_LOAD_34:.*]] = ttg.local_load %{{.*}}#5
-// AMD: %[[LOCAL_LOAD_35:.*]] = ttg.local_load %{{.*}}#7
+// AMD: %[[LOCAL_LOAD_34:.*]] = ttg.local_load %{{.*}}
+// AMD: %[[LOCAL_LOAD_35:.*]] = ttg.local_load %{{.*}}
 // AMD: %[[IF_36:.*]] = scf.if %[[CMPI_29]]
 // AMD: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_34]], %[[LOCAL_LOAD_35]], %[[SELECT_33]]
 // AMD: scf.yield %[[DOT_38]]
@@ -477,34 +464,6 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
 // AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_29]], %[[IF_36]], %[[SELECT_33]]
 // AMD: ttg.local_dealloc %[[LOCAL_ALLOC_0]]
 // AMD: ttg.local_dealloc %[[LOCAL_ALLOC_1]]
-
-// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar
-// AMD_PREFETCH: ttg.local_alloc
-// AMD_PREFETCH: ttg.local_alloc
-// AMD_PREFETCH: tt.load
-// AMD_PREFETCH: tt.load
-// AMD_PREFETCH: tt.load
-// AMD_PREFETCH: ttg.local_store
-// AMD_PREFETCH: ttg.local_store
-// AMD_PREFETCH: tt.load
-// AMD_PREFETCH: ttg.local_load
-// AMD_PREFETCH: tt.load
-// AMD_PREFETCH: tt.load
-// AMD_PREFETCH: ttg.local_load
-// AMD_PREFETCH: scf.for
-// AMD_PREFETCH: ttg.local_store
-// AMD_PREFETCH: ttg.local_store
-// AMD_PREFETCH: tt.dot
-// AMD_PREFETCH: tt.load
-// AMD_PREFETCH: ttg.local_load
-// AMD_PREFETCH: tt.load
-// AMD_PREFETCH: tt.load
-// AMD_PREFETCH: ttg.local_load
-// AMD_PREFETCH: scf.yield
-// AMD_PREFETCH: tt.dot
-// AMD_PREFETCH: tt.dot
-// AMD_PREFETCH: tt.return
-
 tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
                              %76: index,
                              %49: tensor<16x16x!tt.ptr<f16>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},

test/TritonGPU/pipeline-lower-loop.mlir

Lines changed: 21 additions & 0 deletions
@@ -1634,3 +1634,24 @@ tt.func @load_cant_use_async_cp(%lb : index, %ub : index, %step : index,
   tt.return
 }
 }
+
+// -----
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
+// CHECK-LABEL: @scalar_load
+tt.func @scalar_load(%lb : index, %ub : index, %step : index,
+                     %a_ptr_init : !tt.ptr<i32>) -> () {
+  scf.for %iv = %lb to %ub step %step : index {
+    // CHECK: %[[PTR:.+]] = tt.splat %{{.*}} {loop.cluster = 0 : i32, loop.stage = 0 : i32} : !tt.ptr<i32>
+    // CHECK: %[[CP:.+]] = ttg.async_copy_global_to_local %[[PTR]], %{{.+}} {loop.cluster = 0 : i32, loop.stage = 0 : i32}
+    // CHECK: %[[T0:.+]] = ttg.async_commit_group %[[CP]] {loop.cluster = 0 : i32, loop.stage = 0 : i32}
+    // CHECK: %[[T1:.+]] = ttg.async_wait %[[T0]] {loop.cluster = 1 : i32, loop.stage = 3 : i32, num = 0 : i32}
+    // CHECK: %[[L:.+]] = ttg.local_load %{{.+}} token %[[T1]] {loop.cluster = 1 : i32, loop.stage = 3 : i32}
+    // CHECK: %[[R:.+]] = tt.unsplat %[[L]] {loop.cluster = 1 : i32, loop.stage = 3 : i32}
+    // CHECK: "use"(%[[R]]) {loop.cluster = 1 : i32, loop.stage = 3 : i32} : (i32) -> ()
+    %a = tt.load %a_ptr_init {loop.cluster = 1 : i32, loop.stage = 0 : i32} : !tt.ptr<i32>
+    "use"(%a) {loop.cluster = 2 : i32, loop.stage = 3 : i32} : (i32) -> ()
+  } {tt.scheduled_max_stage = 3 : i32}
+  tt.return
+}
+}
