Commit 3ffaa2f

Merge commit 'd997364bd617ba91911ecd73070f57f291611203'
2 parents: d16ef81 + d997364

7 files changed (+165, -131 lines)

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 2 deletions
@@ -10,8 +10,7 @@ name: Integration Tests
 on:
   workflow_dispatch:
   pull_request:
-    # You can name your branch dev-foo to get CI runs.
-    branches: [main, 'dev-**']
+    branches-ignore: ['llvm-**']
   merge_group:
     branches: [main, 'dev-**']
     types: [checks_requested]

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 2 deletions
@@ -9,8 +9,7 @@ name: Integration Tests
 on:
   workflow_dispatch:
   pull_request:
-    # You can name your branch dev-foo to get CI runs.
-    branches: [main, 'dev-**']
+    branches-ignore: ['llvm-**']
   merge_group:
     branches: [main, 'dev-**']
     types: [checks_requested]

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp

Lines changed: 35 additions & 35 deletions
@@ -285,19 +285,6 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
   Location loc = forOp.getLoc();
   SmallVector<Value> predicates(maxStage);
   for (int64_t i = 0; i < maxStage; i++) {
-    if (dynamicLoop) {
-      Type t = ub.getType();
-      // pred = ub > lb + (i * step)
-      Value iv = rewriter.create<arith::AddIOp>(
-          loc, lb,
-          rewriter.create<arith::MulIOp>(
-              loc, step,
-              rewriter.create<arith::ConstantOp>(
-                  loc, rewriter.getIntegerAttr(t, i))));
-      predicates[i] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, iv, ub);
-    }
-
     // special handling for induction variable as the increment is implicit.
     // iv = lb + i * step
     Type t = lb.getType();
@@ -308,6 +295,13 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
         rewriter.create<arith::ConstantOp>(loc,
                                            rewriter.getIntegerAttr(t, i))));
     setValueMapping(forOp.getInductionVar(), iv, i);
+
+    if (dynamicLoop) {
+      // pred = ub > lb + (i * step)
+      predicates[i] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, iv, ub);
+    }
+
     for (Operation *op : opOrder) {
       if (stages[op] > i)
         continue;
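
The prologue hunks above move the dynamic-loop predicate after the induction-variable computation: since iv = lb + i * step is already materialized for setValueMapping, the predicate can compare that value against ub directly instead of rebuilding the same expression. A minimal Python sketch of the resulting predicate logic (illustrative only, not the MLIR rewriter code):

# Sketch: for prologue stage i, the stage only runs if the loop would
# actually execute iteration i, i.e. lb + i * step < ub.
def prologue_predicates(lb: int, ub: int, step: int, max_stage: int):
    predicates = []
    for i in range(max_stage):
        iv = lb + i * step          # same value mapped to the induction variable
        predicates.append(iv < ub)  # pred = ub > lb + (i * step)
    return predicates

# Example: a 2-iteration loop pipelined with max_stage = 3 enables only
# the first two prologue stages.
assert prologue_predicates(0, 2, 1, 3) == [True, True, False]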
@@ -655,50 +649,56 @@ LogicalResult
 LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
                                     llvm::SmallVector<Value> &returnValues) {
   Location loc = forOp.getLoc();
+  Type t = lb.getType();
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.

-  // range_diff = ub - lb
-  // total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
-  Type t = lb.getType();
-  Value zero =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
-  Value one =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
-  Value minusOne =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+  auto createConst = [&](int v) {
+    return rewriter.create<arith::ConstantOp>(loc,
+                                              rewriter.getIntegerAttr(t, v));
+  };
+
+  // total_iterations = cdiv(range_diff, step);
+  // - range_diff = ub - lb
+  // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
+  Value zero = createConst(0);
+  Value one = createConst(1);
   Value stepLessZero = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::slt, step, zero);
   Value stepDecr =
-      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
+      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));

   Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
   Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
   Value rangeDecr =
       rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
   Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);

+  // If total_iters < max_stage, start the epilogue at zero to match the
+  // ramp-up in the prologue.
+  // start_iter = max(0, total_iters - max_stage)
+  Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
+                                               createConst(maxStage));
+  iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
+
   // Capture predicates for dynamic loops.
   SmallVector<Value> predicates(maxStage + 1);

-  for (int64_t i = 0; i < maxStage; i++) {
-    // iterI = total_iters - 1 - i
-    // May go negative...
-    Value minusI =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
-    Value iterI = rewriter.create<arith::AddIOp>(
-        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
-        minusI);
+  for (int64_t i = 1; i <= maxStage; i++) {
     // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
         loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));

-    setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+    setValueMapping(forOp.getInductionVar(), newlastIter, i);
+
+    // increment to next iterI
+    iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);

     if (dynamicLoop) {
-      // pred = iterI >= 0
-      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, iterI, zero);
+      // Disable stages when `i` is greater than total_iters.
+      // pred = total_iters >= i
+      predicates[i] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
     }
   }
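
The epilogue rework replaces the old backwards countdown (iterI = total_iters - 1 - i, which could go negative) with a forward walk that starts at max(0, total_iters - max_stage) and predicates each stage on total_iters >= i. A small Python model of that arithmetic, assuming a positive step and ub >= lb so Python's floor division matches arith.divsi (illustrative only, not the rewriter code):

def epilogue_schedule(lb: int, ub: int, step: int, max_stage: int):
    # total_iterations = cdiv(range_diff, step)
    range_diff = ub - lb
    total_iterations = (range_diff + step + (1 if step < 0 else -1)) // step
    # Start at max(0, total_iters - max_stage) so short loops match the
    # prologue ramp-up instead of producing negative iteration indices.
    iter_i = max(0, total_iterations - max_stage)
    mappings, predicates = [], {}
    for i in range(1, max_stage + 1):
        mappings.append((i, lb + step * iter_i))  # induction value for stage i
        iter_i += 1                               # increment to next iterI
        predicates[i] = total_iterations >= i     # disable stage if i > total_iters
    return total_iterations, mappings, predicates

# Example: 2 iterations pipelined with 3 stages -> stage 3 is predicated off.
total, mappings, preds = epilogue_schedule(0, 2, 1, 3)
assert total == 2 and preds == {1: True, 2: True, 3: False}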

python/test/unit/language/test_pipeliner.py

Lines changed: 29 additions & 0 deletions
@@ -180,3 +180,32 @@ def test_pipeline_vecadd(device):
     assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match"
     # 3. check alloc
     assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match"
+
+
+@pytest.mark.parametrize("ROW_COUNT", [0, 1, 2, 3])
+@pytest.mark.parametrize("NUM_STAGES", [1, 2, 3, 4, 5])
+def test_pipeline_epilogue(ROW_COUNT, NUM_STAGES, device):
+
+    @triton.jit
+    def kernel_up(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+                  NUM_STAGES: tl.constexpr):
+        row_step = tl.num_programs(0)
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        mask = col_offsets < n_cols
+        for row_idx in tl.range(0, n_rows, row_step, num_stages=NUM_STAGES):
+            row_start_ptr = input_ptr + row_idx * input_row_stride
+            input_ptrs = row_start_ptr + col_offsets
+            val = tl.load(input_ptrs, mask=mask, other=-float('inf'))
+            val += 1.0
+            output_row_start_ptr = output_ptr + row_idx * output_row_stride
+            output_ptrs = output_row_start_ptr + col_offsets
+            tl.store(output_ptrs, val, mask=mask)
+
+    width = ROW_COUNT
+    depth = 78
+    x = torch.zeros(width, depth, device=device)
+    y0 = torch.rand_like(x)
+    n_rows, n_cols = x.shape
+    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+    kernel_up[(1, )](y0, x, x.stride(0), y0.stride(0), n_rows, n_cols, BLOCK_SIZE, NUM_STAGES)
+    assert (y0 == torch.ones_like(x)).all()

python/triton/language/_utils.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from typing import List
+
+TRITON_MAX_TENSOR_NUMEL = 1048576
+
+
+def is_power_of_two(x):
+    return (x & (x - 1)) == 0
+
+
+def validate_block_shape(shape: List[int]):
+    numel = 1
+    for i, d in enumerate(shape):
+        if not isinstance(d, int):
+            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
+        if not is_power_of_two(d):
+            raise ValueError(f"Shape element {i} must be a power of 2")
+        numel *= d
+
+    if numel > TRITON_MAX_TENSOR_NUMEL:
+        raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
+    return numel
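
The new helper centralizes block-shape validation: every dimension must be a power of two and the product must stay within TRITON_MAX_TENSOR_NUMEL (1048576). A short usage sketch, assuming the module is importable as triton.language._utils per the file path above:

from triton.language._utils import validate_block_shape

assert validate_block_shape([64, 128]) == 8192   # returns numel on success

try:
    validate_block_shape([64, 3])                # 3 is not a power of two
except ValueError as e:
    print(e)                                     # "Shape element 1 must be a power of 2"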

python/triton/language/core.py

Lines changed: 12 additions & 23 deletions
@@ -13,11 +13,10 @@

 from .._C.libtriton import ir
 from . import semantic
+from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape

 T = TypeVar('T')

-TRITON_MAX_TENSOR_NUMEL = 1048576
-
 TRITON_BUILTIN = "__triton_builtin__"

 PropagateNan = ir.PROPAGATE_NAN
@@ -612,18 +611,11 @@ def __init__(self, element_ty: dtype, shape: List):
         # while tensor's shape is a list of constexpr.

         # shape can be empty ([]) when an input is a 0D tensor.
-        if not shape:
+        self.shape = _unwrap_shape(shape)
+        if not self.shape:
             raise TypeError('0d block_type is forbidden')
-        if isinstance(shape[0], constexpr):
-            shape = [s.value for s in shape]
-
-        self.shape = shape
-        self.numel = 1
-        for s in self.shape:
-            self.numel *= s
-        if self.numel > TRITON_MAX_TENSOR_NUMEL:
-            raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")

+        self.numel = validate_block_shape(self.shape)
         self.name = f'<{self.shape}, {self.element_ty}>'

     def to_ir(self, builder: ir.builder) -> ir.block_type:
@@ -1208,18 +1200,15 @@ def arange(start, end, _builder=None):
     """


-def _shape_check_impl(shape):
+def _unwrap_shape(shape):
     shape = _constexpr_to_value(shape)
-    for i, d in enumerate(shape):
-        if isinstance(d, int):
-            d = constexpr(d)
-        if not isinstance(d, constexpr):
-            raise TypeError(f"Shape element {i} must have type `constexpr`")
-        if not isinstance(d.value, int):
-            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
-        if d.value & (d.value - 1) != 0:
-            raise ValueError(f"Shape element {i} must be a power of 2")
-    return [_constexpr_to_value(x) for x in shape]
+    return [_constexpr_to_value(s) for s in shape]
+
+
+def _shape_check_impl(shape):
+    shape = _unwrap_shape(shape)
+    validate_block_shape(shape)
+    return shape


 @builtin
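
After this refactor, constexpr unwrapping and validation are separate steps: _unwrap_shape only strips constexpr wrappers, while validate_block_shape from _utils.py enforces the power-of-two and max-numel rules for both _shape_check_impl and block_type.__init__. A simplified model of that split (hypothetical names _unwrap_shape_model and _shape_check_model, not the actual core.py code):

from triton.language._utils import validate_block_shape

def _unwrap_shape_model(shape):
    # constexpr-like objects expose .value; plain ints pass through
    return [getattr(s, "value", s) for s in shape]

def _shape_check_model(shape):
    shape = _unwrap_shape_model(shape)
    validate_block_shape(shape)   # raises TypeError/ValueError on bad dims
    return shape

assert _shape_check_model([128, 64]) == [128, 64]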
