
Commit be627a3

Merge commit 'd997364bd617ba91911ecd73070f57f291611203'
2 parents: 3b58b25 + d997364

13 files changed: +335 additions, −148 deletions


.github/workflows/integration-tests.yml

Lines changed: 1 addition & 2 deletions
@@ -10,8 +10,7 @@ name: Integration Tests
 on:
   workflow_dispatch:
   pull_request:
-    # You can name your branch dev-foo to get CI runs.
-    branches: [main, 'dev-**']
+    branches-ignore: ['llvm-**']
   merge_group:
     branches: [main, 'dev-**']
     types: [checks_requested]

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 2 deletions
@@ -9,8 +9,7 @@ name: Integration Tests
 on:
   workflow_dispatch:
   pull_request:
-    # You can name your branch dev-foo to get CI runs.
-    branches: [main, 'dev-**']
+    branches-ignore: ['llvm-**']
   merge_group:
     branches: [main, 'dev-**']
     types: [checks_requested]

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 5 additions & 4 deletions
@@ -88,10 +88,11 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
       // encoding not available
       return resultVals;
     Attribute baseEncoding = encoding;
-    if (isa<AMDMfmaEncodingAttr>(baseEncoding))
-      // TODO: this logic seems incorrect for mfma layout. Skip for now.
-      // We saw mismatches for some flash-attention tests on AMD backend.
-      // Note that this logic works for sliced layout whose parent is
+    if (isa<AMDMfmaEncodingAttr>(baseEncoding) ||
+        isa<AMDWmmaEncodingAttr>(baseEncoding))
+      // TODO: this logic seems incorrect for mfma and wmma layout. Skip for
+      // now. We saw mismatches for some flash-attention and dot tests on AMD
+      // backend. Note that this logic works for sliced layout whose parent is
       // mfma layout. Therefore, this is not combined with the following check.
       return resultVals;
     while (auto sliced = dyn_cast<SliceEncodingAttr>(baseEncoding))

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 7 additions & 0 deletions
@@ -250,6 +250,13 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
                      ArrayRef<unsigned> repShape,
                      ArrayRef<unsigned> paddedRepShape,
                      ArrayRef<unsigned> order, int swizzleByteSize);
+
+// FIXME
+// Exposing to use it in LinearLayoutConversionsTest.cpp
+// Remove it once we fully activate the DotOperand conversion via LLs
+class DotOperandEncodingAttr;
+LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
+                                     DotOperandEncodingAttr dot);
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 6 additions & 10 deletions
@@ -1044,16 +1044,12 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
   return res;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto distributedLayout =
-          mlir::dyn_cast<DistributedEncodingTrait>(parentLayout)) {
-    return distributedLayout.getWarpsPerCTA();
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-DistributedEncodingAttr parent not "
-        "supported yet");
-  }
+  auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
+  auto warps = distributedLayout.getWarpsPerCTA();
+  auto rank = warps.size();
+  auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
+  warps[kDim] = 1;
+  return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
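
The rewritten getWarpsPerCTA takes the parent layout's warps-per-CTA and collapses the k dimension of the dot operand to 1, since warps are not split along k. A minimal Python sketch of that index arithmetic (hypothetical helper, not Triton's API):

def dot_operand_warps_per_cta(parent_warps, op_idx):
    """parent_warps: warps-per-CTA of the parent (e.g. MMA) layout."""
    warps = list(parent_warps)
    rank = len(warps)
    # operand A (op_idx == 0) has k as its last dim, operand B as the second-to-last
    k_dim = rank - 1 if op_idx == 0 else rank - 2
    warps[k_dim] = 1
    return warps

assert dot_operand_warps_per_cta([4, 2], op_idx=0) == [4, 1]
assert dot_operand_warps_per_cta([4, 2], op_idx=1) == [1, 2]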

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 69 additions & 2 deletions
@@ -5,6 +5,7 @@
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/DenseMap.h"
@@ -822,16 +823,82 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return ret;
 }
 
+LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
+                                     DotOperandEncodingAttr dot) {
+  // TODO,BE. Implement ampereMMA in terms of this one
+  int rank = shape.size();
+  auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
+  int kWidth = dot.getKWidth();
+  bool isA = dot.getOpIdx() == 0;
+
+  assert(mma.isAmpere());
+  assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
+         (rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
+
+  MLIRContext *ctx = mma.getContext();
+  SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+
+  // Implement A. For B transpose in the end
+  std::vector<std::vector<int32_t>> registers;
+  std::vector<std::vector<int32_t>> lanes;
+  int32_t i = 1;
+  // kWidth contiguous elements
+  while (i < kWidth) {
+    registers.push_back({i, 0});
+    i *= 2;
+  }
+  // 4 threads per chunk
+  for (int j = 0; j < 2; j++) {
+    lanes.push_back({i, 0});
+    i *= 2;
+  }
+  // 8 threads going down
+  lanes.push_back({0, 1});
+  lanes.push_back({0, 2});
+  lanes.push_back({0, 4});
+  // 2 tiles in column-major order
+  // Just one if it's the B operand
+  if (isA) {
+    registers.push_back({0, 8});
+  }
+  registers.push_back({i, 0});
+
+  if (!isA) {
+    for (auto &r : registers) {
+      std::swap(r[0], r[1]);
+    }
+    for (auto &l : lanes) {
+      std::swap(l[0], l[1]);
+    }
+  }
+
+  LinearLayout ctaLayout(
+      {{S("register"), registers}, {S("lane"), lanes}},
+      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
+
+  auto order = dot.getCTAOrder();
+  assert(order[0] == 1 && order[1] == 0);
+  ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
+}
+
 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
-
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
   }
+
+  // TODO Activate in a follow-up PR
+  // else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+  //   if (mma.isAmpere()) {
+  //     return ampereDotToLinearLayout(shape, *this);
+  //   }
+  //}
+
   if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
     return dotOperandDpasToLinearLayout(*this, shape);
   }
-
   return std::nullopt;
 }
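
ampereDotToLinearLayout builds the register/lane basis vectors for the A operand and derives B by swapping the two output dimensions at the end. A small Python mirror of that basis construction, purely to illustrate which bases come out (hypothetical helper; the real code returns a LinearLayout, not raw lists):

def ampere_dot_bases(k_width, is_a=True):
    registers, lanes = [], []
    i = 1
    while i < k_width:                   # kWidth contiguous elements per thread
        registers.append([i, 0])
        i *= 2
    for _ in range(2):                   # 4 threads per chunk along k
        lanes.append([i, 0])
        i *= 2
    lanes += [[0, 1], [0, 2], [0, 4]]    # 8 threads going down the other dimension
    if is_a:                             # second 16-row tile exists only for A
        registers.append([0, 8])
    registers.append([i, 0])             # second tile along k
    if not is_a:                         # B is the transposed picture
        registers = [[b, a] for a, b in registers]
        lanes = [[b, a] for a, b in lanes]
    return registers, lanes

# kWidth = 2 (e.g. an fp16 A operand of an m16n8k16 MMA tile):
print(ampere_dot_bases(2, is_a=True))
# -> ([[1, 0], [0, 8], [8, 0]], [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]])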

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp

Lines changed: 35 additions & 35 deletions
@@ -285,19 +285,6 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
   Location loc = forOp.getLoc();
   SmallVector<Value> predicates(maxStage);
   for (int64_t i = 0; i < maxStage; i++) {
-    if (dynamicLoop) {
-      Type t = ub.getType();
-      // pred = ub > lb + (i * step)
-      Value iv = rewriter.create<arith::AddIOp>(
-          loc, lb,
-          rewriter.create<arith::MulIOp>(
-              loc, step,
-              rewriter.create<arith::ConstantOp>(
-                  loc, rewriter.getIntegerAttr(t, i))));
-      predicates[i] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, iv, ub);
-    }
-
     // special handling for induction variable as the increment is implicit.
     // iv = lb + i * step
     Type t = lb.getType();
@@ -308,6 +295,13 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
             rewriter.create<arith::ConstantOp>(loc,
                                                rewriter.getIntegerAttr(t, i))));
     setValueMapping(forOp.getInductionVar(), iv, i);
+
+    if (dynamicLoop) {
+      // pred = ub > lb + (i * step)
+      predicates[i] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, iv, ub);
+    }
+
     for (Operation *op : opOrder) {
       if (stages[op] > i)
         continue;
@@ -655,50 +649,56 @@ LogicalResult
 LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
                                     llvm::SmallVector<Value> &returnValues) {
   Location loc = forOp.getLoc();
+  Type t = lb.getType();
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
 
-  // range_diff = ub - lb
-  // total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
-  Type t = lb.getType();
-  Value zero =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
-  Value one =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
-  Value minusOne =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+  auto createConst = [&](int v) {
+    return rewriter.create<arith::ConstantOp>(loc,
+                                              rewriter.getIntegerAttr(t, v));
+  };
+
+  // total_iterations = cdiv(range_diff, step);
+  // - range_diff = ub - lb
+  // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
+  Value zero = createConst(0);
+  Value one = createConst(1);
   Value stepLessZero = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::slt, step, zero);
   Value stepDecr =
-      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
+      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
 
   Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
   Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
   Value rangeDecr =
       rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
   Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
 
+  // If total_iters < max_stage, start the epilogue at zero to match the
+  // ramp-up in the prologue.
+  // start_iter = max(0, total_iters - max_stage)
+  Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
                                                createConst(maxStage));
+  iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
+
   // Capture predicates for dynamic loops.
   SmallVector<Value> predicates(maxStage + 1);
 
-  for (int64_t i = 0; i < maxStage; i++) {
-    // iterI = total_iters - 1 - i
-    // May go negative...
-    Value minusI =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
-    Value iterI = rewriter.create<arith::AddIOp>(
-        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
-        minusI);
+  for (int64_t i = 1; i <= maxStage; i++) {
     // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
        loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
 
-    setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+    setValueMapping(forOp.getInductionVar(), newlastIter, i);
+
+    // increment to next iterI
+    iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
 
     if (dynamicLoop) {
-      // pred = iterI >= 0
-      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, iterI, zero);
+      // Disable stages when `i` is greater than total_iters.
+      // pred = total_iters >= i
+      predicates[i] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
     }
   }
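
The rewritten epilogue computes total_iterations = cdiv(ub - lb, step), starts draining at max(0, total_iterations - maxStage), and guards stage i with pred = total_iterations >= i, so loops shorter than the pipeline depth no longer replay stages that never ran in the prologue. A rough Python sketch of that bookkeeping (not the MLIR builder code; a positive step and ub >= lb are assumed so Python's // matches signed truncation):

def epilogue_schedule(lb, ub, step, max_stage):
    """Return (iv, stage, enabled) triples the epilogue would emit."""
    total_iterations = (ub - lb + step - 1) // step   # cdiv(range_diff, step)
    # If total_iters < max_stage, start at zero to match the prologue ramp-up.
    iter_i = max(0, total_iterations - max_stage)
    out = []
    for i in range(1, max_stage + 1):
        new_last_iter = lb + step * iter_i            # induction value for stage i
        enabled = total_iterations >= i               # pred = total_iters >= i
        out.append((new_last_iter, i, enabled))
        iter_i += 1
    return out

# Two epilogue stages of a 3-iteration loop: both stages run.
print(epilogue_schedule(lb=0, ub=6, step=2, max_stage=2))   # [(2, 1, True), (4, 2, True)]
# A 1-iteration loop with 2 pipeline stages: the second epilogue stage is disabled.
print(epilogue_schedule(lb=0, ub=2, step=2, max_stage=2))   # [(0, 1, True), (2, 2, False)]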

python/test/unit/language/test_pipeliner.py

Lines changed: 29 additions & 0 deletions
@@ -180,3 +180,32 @@ def test_pipeline_vecadd(device):
     assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match"
     # 3. check alloc
     assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match"
+
+
+@pytest.mark.parametrize("ROW_COUNT", [0, 1, 2, 3])
+@pytest.mark.parametrize("NUM_STAGES", [1, 2, 3, 4, 5])
+def test_pipeline_epilogue(ROW_COUNT, NUM_STAGES, device):
+
+    @triton.jit
+    def kernel_up(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+                  NUM_STAGES: tl.constexpr):
+        row_step = tl.num_programs(0)
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        mask = col_offsets < n_cols
+        for row_idx in tl.range(0, n_rows, row_step, num_stages=NUM_STAGES):
+            row_start_ptr = input_ptr + row_idx * input_row_stride
+            input_ptrs = row_start_ptr + col_offsets
+            val = tl.load(input_ptrs, mask=mask, other=-float('inf'))
+            val += 1.0
+            output_row_start_ptr = output_ptr + row_idx * output_row_stride
+            output_ptrs = output_row_start_ptr + col_offsets
+            tl.store(output_ptrs, val, mask=mask)
+
+    width = ROW_COUNT
+    depth = 78
+    x = torch.zeros(width, depth, device='cuda')
+    y0 = torch.rand_like(x)
+    n_rows, n_cols = x.shape
+    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+    kernel_up[(1, )](y0, x, x.stride(0), y0.stride(0), n_rows, n_cols, BLOCK_SIZE, NUM_STAGES)
+    assert (y0 == torch.ones_like(x)).all()

python/triton/language/_utils.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from typing import List
+
+TRITON_MAX_TENSOR_NUMEL = 1048576
+
+
+def is_power_of_two(x):
+    return (x & (x - 1)) == 0
+
+
+def validate_block_shape(shape: List[int]):
+    numel = 1
+    for i, d in enumerate(shape):
+        if not isinstance(d, int):
+            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
+        if not is_power_of_two(d):
+            raise ValueError(f"Shape element {i} must be a power of 2")
+        numel *= d
+
+    if numel > TRITON_MAX_TENSOR_NUMEL:
+        raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
+    return numel
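
The new helper centralizes the block-shape checks that core.py previously did inline. A quick usage sketch (values are arbitrary; the import path follows the new file's location):

from triton.language._utils import validate_block_shape

assert validate_block_shape([64, 128]) == 64 * 128   # returns numel when the shape is valid
try:
    validate_block_shape([64, 96])                    # 96 is not a power of 2
except ValueError as e:
    print(e)                                          # "Shape element 1 must be a power of 2"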

python/triton/language/core.py

Lines changed: 12 additions & 23 deletions
@@ -13,11 +13,10 @@
 
 from .._C.libtriton import ir
 from . import semantic
+from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape
 
 T = TypeVar('T')
 
-TRITON_MAX_TENSOR_NUMEL = 1048576
-
 TRITON_BUILTIN = "__triton_builtin__"
 
 PropagateNan = ir.PROPAGATE_NAN
@@ -612,18 +611,11 @@ def __init__(self, element_ty: dtype, shape: List):
         # while tensor's shape is a list of constexpr.
 
         # shape can be empty ([]) when an input is a 0D tensor.
-        if not shape:
+        self.shape = _unwrap_shape(shape)
+        if not self.shape:
             raise TypeError('0d block_type is forbidden')
-        if isinstance(shape[0], constexpr):
-            shape = [s.value for s in shape]
-
-        self.shape = shape
-        self.numel = 1
-        for s in self.shape:
-            self.numel *= s
-        if self.numel > TRITON_MAX_TENSOR_NUMEL:
-            raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
 
+        self.numel = validate_block_shape(self.shape)
         self.name = f'<{self.shape}, {self.element_ty}>'
 
     def to_ir(self, builder: ir.builder) -> ir.block_type:
@@ -1208,18 +1200,15 @@ def arange(start, end, _builder=None):
     """
 
 
-def _shape_check_impl(shape):
+def _unwrap_shape(shape):
     shape = _constexpr_to_value(shape)
-    for i, d in enumerate(shape):
-        if isinstance(d, int):
-            d = constexpr(d)
-        if not isinstance(d, constexpr):
-            raise TypeError(f"Shape element {i} must have type `constexpr`")
-        if not isinstance(d.value, int):
-            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
-        if d.value & (d.value - 1) != 0:
-            raise ValueError(f"Shape element {i} must be a power of 2")
-    return [_constexpr_to_value(x) for x in shape]
+    return [_constexpr_to_value(s) for s in shape]
+
+
+def _shape_check_impl(shape):
+    shape = _unwrap_shape(shape)
+    validate_block_shape(shape)
+    return shape
 
 
 @builtin
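
After the refactor, block_type unwraps constexpr dimensions via _unwrap_shape and delegates the numel and power-of-two checks to validate_block_shape, so the builtin-argument path (_shape_check_impl) and the type constructor share one validation routine. A sketch of the observable behaviour (assuming block_type remains exported from triton.language; values are arbitrary):

import triton.language as tl

t = tl.block_type(tl.float32, [64, 64])
print(t.shape, t.numel)                        # [64, 64] 4096
try:
    tl.block_type(tl.float32, [2048, 2048])    # 2048*2048 exceeds TRITON_MAX_TENSOR_NUMEL (1048576)
except ValueError as e:
    print(e)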
