Skip to content

Commit 9ef5016

Browse files
dbd64 and Daniel Donenfeld authored
Fix for compiler crash when analyzing loop scalar mask (#314)
It was previously assumed that start and end would be set when analyzing the mask state in a loop. When the mask state is set to scalar, compute the start and end for use in the rest of the analysis. --------- Co-authored-by: Daniel Donenfeld <[email protected]>
1 parent 2b24609 commit 9ef5016

File tree

2 files changed

+62
-11
lines changed

2 files changed

+62
-11
lines changed

lib/Analysis/MaskAnalysis.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ LogicalResult MaskState::minStateScalar(const MaskState &lhsState,
235235
}
236236
} else {
237237
InFlightDiagnostic diag =
238-
emitError(loc) << "Unexpected case where both lhs and rhs are not scalars";
238+
emitError(loc)
239+
<< "Unexpected case where both lhs and rhs are not scalars";
239240
return failure();
240241
}
241242
return success();
@@ -329,7 +330,7 @@ LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc,
329330
if (failed(rhsState.parse(andOp.getRhs(), loc, builder)))
330331
return failure();
331332

332-
if(!lhsState.isMask() || !rhsState.isMask()) {
333+
if (!lhsState.isMask() || !rhsState.isMask()) {
333334
return this->minStateScalar(lhsState, rhsState, loc, builder);
334335
}
335336
return this->minStates(lhsState, rhsState, loc, builder);
@@ -363,8 +364,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
363364
// We only support sge against 0 for lower bounds. Dims already has an
364365
// implicit assumption that the lower bound is 0, so if we see this, assume
365366
// the comparison evaluates to true.
366-
if (cmpOp.getPredicate() == arith::CmpIPredicate::sge
367-
&& !(rhsState.scalar && hasConstZero(rhsState.scalar))) {
367+
if (cmpOp.getPredicate() == arith::CmpIPredicate::sge &&
368+
!(rhsState.scalar && hasConstZero(rhsState.scalar))) {
368369
InFlightDiagnostic diag = emitError(loc)
369370
<< "Unsupported cmpi with rhs not equal to 0";
370371
return failure();
@@ -383,8 +384,11 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
383384
cmpDim = i;
384385
}
385386
}
386-
assert(cmpDim != -1 &&
387-
"Unexpected case where no dimension has size larger than 1");
387+
assert(
388+
cmpDim != -1 ||
389+
(!lhsState.scalar && cmpOp.getPredicate() == arith::CmpIPredicate::slt ||
390+
cmpOp.getPredicate() == arith::CmpIPredicate::ult) &&
391+
"Unexpected case where no dimension has size larger than 1");
388392

389393
OpFoldResult newDim;
390394
if (lhsState.scalar) {
@@ -397,10 +401,10 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
397401
// should be loaded/stored by inserting a comparison + select:
398402
// dim = lhs < rhs ? lhs.dim : 0
399403
newDim = compareOFRs(lhsState.scalar, rhsState.scalar, cmpOp.getPredicate(),
400-
lhsState.dims[cmpDim], builder.getIndexAttr(0),
401-
loc, builder);
404+
lhsState.dims[cmpDim], builder.getIndexAttr(0), loc,
405+
builder);
402406
} else if (cmpOp.getPredicate() == arith::CmpIPredicate::slt ||
403-
cmpOp.getPredicate() == arith::CmpIPredicate::ult) {
407+
cmpOp.getPredicate() == arith::CmpIPredicate::ult) {
404408
// Important:
405409
// In the case where the values we are loading are entirely masked off like
406410
// the following:
@@ -418,8 +422,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
418422
newEnd = maxOFRs(newEnd, lhsState.start, loc, builder);
419423
newDim = subOFRs(newEnd, lhsState.start, loc, builder);
420424
} else {
421-
assert(cmpOp.getPredicate() == arith::CmpIPredicate::sge && rhsState.scalar
422-
&& hasConstZero(rhsState.scalar));
425+
assert(cmpOp.getPredicate() == arith::CmpIPredicate::sge &&
426+
rhsState.scalar && hasConstZero(rhsState.scalar));
423427
newDim = lhsState.dims[cmpDim];
424428
}
425429

@@ -507,6 +511,12 @@ LogicalResult MaskState::parseLoopIterArg(Value v, const Location loc,
507511
}
508512
}
509513

514+
if (!lhsState.start && !lhsState.end) {
515+
assert(lhsState.scalar && "MaskState must have a scalar");
516+
lhsState.start = builder.getIndexAttr(0);
517+
lhsState.end = lhsState.scalar;
518+
}
519+
510520
auto dist = subOFRs(lhsState.end, lhsState.start, loc, builder);
511521
this->start = forOp.getRegionIterArg(argIndex + 1);
512522
this->end = addOFRs(this->start, dist, loc, builder);
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s
2+
3+
module {
4+
tt.func public @scalar_mask_loop(%arg0: !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
5+
%cst = arith.constant dense<4.480000e+02> : tensor<1xf32>
6+
%cst_0 = arith.constant dense<-4.480000e+02> : tensor<1xf32>
7+
%c1_i32 = arith.constant 1 : i32
8+
%c0_i32 = arith.constant 0 : i32
9+
%0 = tt.get_program_id x : i32
10+
%1 = tt.get_num_programs x : i32
11+
%2 = arith.addi %arg3, %1 : i32
12+
%3 = arith.subi %2, %c1_i32 : i32
13+
%4 = arith.divsi %3, %1 : i32
14+
%5 = tt.load %arg2 : !tt.ptr<f32>
15+
%6 = tt.splat %0 : i32 -> tensor<1xi32>
16+
%7 = scf.for %arg4 = %c0_i32 to %4 step %c1_i32 iter_args(%arg5 = %6) -> (tensor<1xi32>) : i32 {
17+
%8 = tt.splat %arg3 : i32 -> tensor<1xi32>
18+
%9 = arith.cmpi slt, %arg5, %8 : tensor<1xi32>
19+
%10 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<1x!tt.ptr<bf16>>
20+
%11 = tt.addptr %10, %arg5 : tensor<1x!tt.ptr<bf16>>, tensor<1xi32>
21+
%12 = tt.load %11, %9 : tensor<1x!tt.ptr<bf16>>
22+
%13 = arith.extf %12 : tensor<1xbf16> to tensor<1xf32>
23+
%14 = tt.splat %5 : f32 -> tensor<1xf32>
24+
%15 = arith.mulf %13, %14 : tensor<1xf32>
25+
%16 = tt.clampf %15, %cst_0, %cst, propagateNan = none : tensor<1xf32>
26+
%17 = tt.fp_to_fp %16, rounding = rtne : tensor<1xf32> -> tensor<1xf8E4M3FN>
27+
%18 = tt.splat %arg0 : !tt.ptr<f8E4M3FN> -> tensor<1x!tt.ptr<f8E4M3FN>>
28+
%19 = tt.addptr %18, %arg5 : tensor<1x!tt.ptr<f8E4M3FN>>, tensor<1xi32>
29+
tt.store %19, %17, %9 : tensor<1x!tt.ptr<f8E4M3FN>>
30+
%20 = tt.splat %1 : i32 -> tensor<1xi32>
31+
%21 = arith.addi %arg5, %20 : tensor<1xi32>
32+
scf.yield %21 : tensor<1xi32>
33+
}
34+
tt.return
35+
}
36+
}
37+
38+
39+
// CHECK: %8 = scf.for %arg4 = %c0_i32 to %6 step %c1_i32 iter_args(%arg5 = %1) -> (index) : i32 {
40+
// CHECK: %9 = tts.make_tptr %arg1 to sizes: [1], strides: [%c0], offsets: [%arg5], shape: [0], order: [] : <bf16> to tensor<1x!tt.ptr<bf16>>
41+
// CHECK: %10 = "tts.load"(%9) <{operandSegmentSizes = array<i32: 1, 0, 0>, static_mask_dims = array<i64: 1>}> : (tensor<1x!tt.ptr<bf16>>) -> tensor<1xbf16>

0 commit comments

Comments
 (0)