Commit b9afb6c

Merge commit '82e7a32179d6d3ecadac88a06916ba2b52bcfbdb'
2 parents: 13725c1 + 82e7a32

32 files changed, +799 −564 lines

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 6 deletions
@@ -25,7 +25,6 @@ namespace triton {
 constexpr int patternBenefitDefault = 1;
 constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
 constexpr int patternBenefitClampOptimizedPattern = 20;
-constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;
 
 struct BackendCallbacks {
   /**
@@ -50,7 +49,7 @@ void populateElementwiseOpToLLVMPatterns(
 // callback receives 1) the current source op, 2) the number of issued LLVM
 // instructions and 3) their input types. Each MLIR backend can provide a
 // callback and, thus, handle backend-specific behaviors.
-void populateMemoryOpToLLVMPattern(
+void populateMemoryOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
     RewritePatternSet &patterns, PatternBenefit benefit,
     std::optional<BackendCallbacks> backendCallbacks = std::nullopt);
@@ -102,10 +101,6 @@ void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns,
                                            PatternBenefit benefit);
 
-void populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
-    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
-    RewritePatternSet &patterns, PatternBenefit benefit);
-
 void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         const TargetInfoBase &targetInfo,

lib/Analysis/Utility.cpp

Lines changed: 1 addition & 2 deletions
@@ -720,8 +720,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
   auto ans = mmaLayout.getVersionMajor() == 3 &&
              dotOperandLayout.getOpIdx() == 0 &&
              mmaLayout.getWarpsPerCTA()[1] == 1 &&
-             !cvtNeedsSharedMemory(parentTy, srcTy) &&
-             (elementTypeSize == 16 || elementTypeSize == 8) &&
+             !cvtNeedsSharedMemory(parentTy, srcTy) && elementTypeSize == 8 &&
              dotOperandLayout.getKWidth() == 32 / elementTypeSize;
   return ans;
 }

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 17 additions & 35 deletions
@@ -16,11 +16,7 @@
 
 namespace {
 
-using ::mlir::LLVM::getMultiDimOffset;
-using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
-using ::mlir::LLVM::getWrappedMultiDimOffset;
-using ::mlir::LLVM::linearize;
-
+using namespace mlir;
 using namespace mlir::triton::gpu;
 
 // XXX(Keren): A temporary knob to control the use of legacy MMA conversion
@@ -105,13 +101,14 @@ struct ConvertLayoutOpConversion
       // of performance issue observed.
       for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
         SmallVector<Value> multiDimOffset =
-            getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type,
-                              multiDimCTAInRepId, shapePerCTATile);
-        SmallVector<Value> multiDimOffsetWrapped = getWrappedMultiDimOffset(
-            rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile,
-            shapePerCTA);
-        Value offset = linearize(rewriter, loc, multiDimOffsetWrapped,
-                                 paddedRepShape, outOrd);
+            LLVM::getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId,
+                                    type, multiDimCTAInRepId, shapePerCTATile);
+        SmallVector<Value> multiDimOffsetWrapped =
+            LLVM::getWrappedMultiDimOffset(rewriter, loc, multiDimOffset,
+                                           origRepShape, shapePerCTATile,
+                                           shapePerCTA);
+        Value offset = LLVM::linearize(rewriter, loc, multiDimOffsetWrapped,
+                                       paddedRepShape, outOrd);
         auto elemPtrTy = smemBase.getType();
         Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset);
         auto vecTy = vec_ty(llvmElemTy, vec);
@@ -267,7 +264,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   // conversions. TODO(jlebar): Eventually we want this to be the only pattern.
   explicit ConvertLayoutOpUsingLinearLayoutsConversion(
       LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
-      PatternBenefit benefit = 2)
+      PatternBenefit benefit = 1)
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
   }
 
@@ -395,16 +392,6 @@
     if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) {
       return failure();
     }
-    // FIXME [Dot LL] Remove this once we implement this trick in LLs
-    if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
-      return failure();
-    }
-
-    // The following check can be removed when generalized warp shuffle
-    // conversions are ready:
-    if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
-      return failure();
-    }
 
     assert(cvtNeedsSharedMemory(srcTy, dstTy));
 
@@ -666,22 +653,17 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 
 } // namespace
 
-void mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
+void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
     RewritePatternSet &patterns, PatternBenefit benefit) {
+  if (useLegacyMMAConversion) {
+    // Prioritize the legacy MMA conversion over the LinearLayout conversion.
+    // Only for debugging purposes.
+    patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo,
+                                            benefit.getBenefit() + 1);
+  }
   patterns.add<ConvertLayoutOpUsingLinearLayoutsConversion>(
       typeConverter, targetInfo, benefit);
-}
-
-void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
-    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  // We prefer using the linear layout conversion, so it gets a higher benefit.
-  // Eventually the LL conversion will subsume all of the others and be the only
-  // one left.
-  mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
-      typeConverter, targetInfo, patterns, benefit.getBenefit() + 1);
   patterns.add<ConvertLayoutOpBlockedToDotOpShortcutConversion>(
       typeConverter, targetInfo, benefit);
-  patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
 }

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 7 additions & 36 deletions
@@ -121,33 +121,12 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
 
   // FIXME [Dot LL]
   // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedDotOpLayout(MemDescType srcTy,
-                                     RankedTensorType dstTy) {
-    auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
-    auto dstLayout = dstTy.getEncoding();
-    auto bitwidth = dstTy.getElementTypeBitWidth();
-    auto rank = dstTy.getRank();
+  static bool isSupportedLayout(Attribute dstLayout) {
+    if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
+            LinearEncodingAttr>(dstLayout))
+      return true;
     if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
-      auto vecWidth = 32 / bitwidth;
-      auto kWidth = dot.getKWidth();
-      auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
-      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
-        auto needTrans = kOrder != srcLayout.getOrder()[0];
-        auto canUseLdmatrix =
-            (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
-        if (mma.isHopper()) {
-          // I think we should be able to remove this condition, but it's here
-          // as the legacy ldmatrix path does not support it
-          canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32;
-        }
-        // If we remove this one, ldmatrix will IMA. It can probably be relaxed
-        // though
-        canUseLdmatrix &=
-            srcTy.getShape()[0] >= 8 &&
-            srcTy.getShape()[1] >= 4 * kWidth & dstTy.getRank() <= 2;
-        return !canUseLdmatrix;
-      }
-      if (isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(dot.getParent()))
+      if (isa<MmaEncodingTrait>(dot.getParent()))
         return true;
     }
     return false;
@@ -156,12 +135,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
   LogicalResult
   matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemDescType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
     Attribute dstLayout = dstTy.getEncoding();
-    if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
-            LinearEncodingAttr>(dstLayout) ||
-        isSupportedDotOpLayout(srcTy, dstTy)) {
+    if (isSupportedLayout(dstLayout)) {
      return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }
@@ -198,11 +174,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto loc = op.getLoc();
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getResult().getType();
-    auto dstShape = dstTy.getShape();
-    auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
-    assert((!isa<DotOperandEncodingAttr>(dstTy.getEncoding()) ||
-            isSupportedDotOpLayout(srcTy, dstTy)) &&
-           "Unexpected rank of ConvertLayout(shared->distributed)");
 
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         loc, adaptor.getSrc(),
@@ -265,7 +236,7 @@ struct LocalStoreOpConversion
 
 } // namespace
 
-void mlir::triton::populateMemoryOpToLLVMPattern(
+void mlir::triton::populateMemoryOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
     RewritePatternSet &patterns, PatternBenefit benefit,
     std::optional<BackendCallbacks> backendCallbacks) {

lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp

Lines changed: 22 additions & 4 deletions
@@ -17,9 +17,11 @@ namespace {
 // dot(a, b, inputPrecision="tf32x3") ->
 //  let aBig = f32ToTF32(a), aSmall = a - aBig;
 //  let bBig = f32ToTF32(b), bSmall = b - bBig;
-//  dot(aSmall, bBig, inputPrecision="tf32") +
-//  dot(aBig, bSmall, inputPrecision="tf32") +
-//  dot(aBig, bBig, inputPrecision="tf32")
+//  let small = dot(aSmall, bBig, inputPrecision="tf32") +
+//              dot(aBig, bSmall, inputPrecision="tf32")
+//  let masked_nans = replaceNansWithZeros(small)
+//  let big = dot(aBig, bBig, inputPrecision="tf32")
+//  return big + masked_nans;
 class TF32x3 : public OpRewritePattern<DotOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -62,6 +64,13 @@ class TF32x3 : public OpRewritePattern<DotOp> {
                                      InputPrecision::TF32,
                                      dotOp.getMaxNumImpreciseAcc());
     };
+    auto replaceNansWithZeros = [&](Value value) -> Value {
+      auto nans = rewriter.create<arith::CmpFOp>(
+          dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
+      auto zero = zeroLike(value);
+      return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
+                                              value);
+    };
 
     auto aBig = f32ToTF32(dotOp.getA());
     auto aSmall = sub(dotOp.getA(), aBig);
@@ -73,7 +82,16 @@ class TF32x3 : public OpRewritePattern<DotOp> {
 
     auto dot1 = dot(aSmall, bBig, zero);
     auto dot2 = dot(aBig, bSmall, dot1);
-    auto dot3 = dot(aBig, bBig, dot2);
+
+    // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0.
+    // If rhs is +infinity, we will have:
+    // +infinity * 1.0 = +infinity
+    // +infinity * 0.0 = NaN
+    // We would get the wrong result if we sum these partial products. Instead,
+    // we must override any accumulated result if the last partial product is
+    // non-finite.
+    auto dot2withZeroedNans = replaceNansWithZeros(dot2);
+    auto dot3 = dot(aBig, bBig, dot2withZeroedNans);
 
     auto sum = add(dot3, dotOp.getC());
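
To see why the masking matters, here is a minimal numpy sketch of the corner case described in the new comment (illustrative only; f32_to_tf32 is a hypothetical helper that emulates TF32 truncation, not a Triton API):

import numpy as np

def f32_to_tf32(x):
    # TF32 keeps 10 explicit mantissa bits; emulate the truncation by
    # zeroing the low 13 bits of each float32's encoding.
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

with np.errstate(invalid="ignore"):
    a = np.array([1.0], dtype=np.float32)     # a_big = 1.0, a_small = 0.0
    b = np.array([np.inf], dtype=np.float32)  # the problematic operand
    a_big, a_small = f32_to_tf32(a), a - f32_to_tf32(a)
    b_big, b_small = f32_to_tf32(b), b - f32_to_tf32(b)

    small = a_small * b_big + a_big * b_small  # 0.0 * inf -> NaN
    big = a_big * b_big                        # inf, the correct product

    print(big + small)                                            # [nan] (wrong)
    print(big + np.where(np.isnan(small), np.float32(0), small))  # [inf] (right)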

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 4 additions & 1 deletion
@@ -399,7 +399,10 @@ struct MMAV3UseRegOperand
         dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/kWidth);
     auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(),
                                        dotOperandEnc);
-    if (!matchMmaV3AndDotOperandLayout(srcTy, newTy))
+    // TODO(Keren): relax the condition once
+    // https://github.com/triton-lang/triton/pull/5419 is merged
+    if (!cvtReordersRegisters(srcTy, newTy) &&
+        !matchMmaV3AndDotOperandLayout(srcTy, newTy))
       return failure();
 
     Value newOperand =

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 2 additions & 1 deletion
@@ -103,7 +103,8 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
     return op;
   }
 
-  assert("don't know how to predicate this op" && false);
+  op->emitError("pipeliner doesn't know how to predicate this op.");
+  llvm::report_fatal_error("Fatal pipeliner error");
   return op;
 }
 

python/test/unit/language/test_core.py

Lines changed: 42 additions & 0 deletions
@@ -1690,6 +1690,48 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
     assert (torch.equal(X, Y))
 
 
+@pytest.mark.interpreter
+@pytest.mark.skipif((is_cuda() and torch.cuda.get_device_capability()[0] < 9) or is_hip(),
+                    reason="Requires compute capability >= 9 for NV")
+def test_load_scope_sem_coop_grid_cta_not_one(device):
+
+    @triton.jit
+    def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr):
+        numel = 512
+        offset = tl.program_id(0) * BLOCK_SIZE
+        index = offset
+        mask = index < numel
+        a = tl.load(ptrs, mask=mask)
+        tl.store(ptrs, a)
+
+    block_size = 128
+    data = torch.zeros((128, ), device=device, dtype=torch.float32)
+
+    out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=True)
+    out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=False)
+
+
+@pytest.mark.interpreter
+@pytest.mark.skipif(is_hip(), reason="Not implemented for AMD at this moment")
+def test_load_scope_sem_coop_grid_cta_one(device):
+
+    @triton.jit
+    def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr):
+        numel = 512
+        offset = tl.program_id(0) * BLOCK_SIZE
+        index = offset
+        mask = index < numel
+        a = tl.load(ptrs, mask=mask)
+        tl.store(ptrs, a)
+
+    block_size = 128
+    data = torch.zeros((128, ), device=device, dtype=torch.float32)
+
+    # Should do nothing different for num_ctas=1 (with coop launch grid)
+    out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=True)
+    out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=False)
+
+
 # ---------------
 # test cast
 # ---------------

python/triton/runtime/autotuner.py

Lines changed: 8 additions & 4 deletions
@@ -4,7 +4,7 @@
 import os
 import time
 import inspect
-from typing import Dict
+from typing import Dict, Tuple, List, Optional
 
 from .jit import KernelInterface
 from .errors import OutOfResources, PTXASError
@@ -23,7 +23,7 @@ def __init__(
         restore_value,
         pre_hook=None,
         post_hook=None,
-        prune_configs_by: Dict = None,
+        prune_configs_by: Optional[Dict] = None,
         warmup=None,
         rep=None,
         use_cuda_graph=False,
@@ -40,7 +40,7 @@ def __init__(
         else:
            self.configs = configs
         self.keys = key
-        self.cache = {}
+        self.cache: Dict[Tuple, Config] = {}
         self.arg_names = arg_names
 
         # Reset to zero or restore values
@@ -211,14 +211,18 @@ def run(self, *args, **kwargs):
             self.nargs = None
         return ret
 
-    def prune_configs(self, kwargs):
+    def prune_configs(self, kwargs: Dict) -> List[Config]:
         pruned_configs = self.configs
         if self.early_config_prune:
             pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
         if self.perf_model:
             top_k = self.configs_top_k
             if isinstance(top_k, float) and top_k <= 1.0:
                 top_k = int(len(self.configs) * top_k)
+            elif not isinstance(top_k, int):
+                # Slice index must be an integer
+                raise TypeError("Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int")
+
             if len(pruned_configs) > top_k:
                 est_timing = {
                     config: self.perf_model(
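
The new elif in prune_configs turns what used to be a confusing slice error into an explicit TypeError. A standalone sketch of the intended top_k semantics (resolve_top_k is a hypothetical helper for illustration, not part of the Triton API):

def resolve_top_k(top_k, n_configs):
    # Mirrors prune_configs: a float <= 1.0 selects a fraction of the
    # configs, an int selects an absolute count, anything else is rejected.
    if isinstance(top_k, float) and top_k <= 1.0:
        return int(n_configs * top_k)
    if not isinstance(top_k, int):
        raise TypeError("top_k must be either 1) a float <= 1.0 or 2) an int")
    return top_k

assert resolve_top_k(0.25, 8) == 2  # keep the predicted-fastest quarter
assert resolve_top_k(3, 8) == 3     # keep the predicted-fastest three
# resolve_top_k("3", 8) raises immediately instead of failing later when
# the config list is sorted and sliced with a non-integer index.
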
python/triton/runtime/jit.py

Lines changed: 2 additions & 1 deletion
@@ -501,7 +501,7 @@ def _call_hook(
         name = self.fn.__name__
         module = self.fn.__module__
         arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
-        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})"
+        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"
 
         class JitFunctionInfo:
 
@@ -521,6 +521,7 @@ def __init__(self, module, name, jit_function):
             'num_ctas': options.num_ctas,
             'num_stages': options.num_stages,
             'enable_fp_fusion': options.enable_fp_fusion,
+            'launch_cooperative_grid': options.launch_cooperative_grid,
             'extern_libs': options.extern_libs,
             'configs': configs,
             'specialization_data': specialization_data,
