Skip to content

Commit 18932d1

Browse files
Merge commit '493f9917c21061164de2df08ba75c7ba8da3130a'
2 parents 654f827 + 493f991 commit 18932d1

File tree

13 files changed

+122
-77
lines changed

13 files changed

+122
-77
lines changed

.github/workflows/integration-tests.yml

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ jobs:
5252
with:
5353
files: |
5454
cmake/*.txt
55+
cmake/*.json
5556
- name: Detect if enough time has passed since last post-submit run
5657
id: detect-time
5758
if: github.event_name == 'push'
@@ -127,7 +128,7 @@ jobs:
127128
- name: Compute hash of pre-commit config
128129
id: cache-key
129130
run: |
130-
echo "pre_commit_hash=$(sha256sum .pre-commit-config.yaml)" >> $GITHUB_OUTPUT
131+
echo "pre_commit_hash=$(sha256sum .pre-commit-config.yaml | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT
131132
shell: bash
132133
- name: Cache pre-commit's cache dir
133134
uses: actions/cache@v4
@@ -165,9 +166,20 @@ jobs:
165166
- name: Compute cache keys
166167
id: cache-key
167168
run: |
168-
echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT
169-
echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT
170-
echo "json=$(cat cmake/json-version.txt)" >> $GITHUB_OUTPUT
169+
llvm_file="cmake/llvm-hash.txt"
170+
nvidia_file="cmake/nvidia-toolchain-version.json"
171+
json_file="cmake/json-version.txt"
172+
173+
# Check if files exist before proceeding
174+
if [[ ! -f "$llvm_file" || ! -f "$nvidia_file" || ! -f "$json_file" ]]; then
175+
echo "Error: Required dependency files are missing."
176+
exit 1
177+
fi
178+
179+
# Process the files if they exist
180+
echo "llvm=$(cat $llvm_file | cut -c 1-8)" >> $GITHUB_OUTPUT
181+
echo "nvidia=$(sha256sum $nvidia_file | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT
182+
echo "json=$(cat $json_file)" >> $GITHUB_OUTPUT
171183
echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT
172184
shell: bash
173185
- name: Cache build dependencies
@@ -306,9 +318,20 @@ jobs:
306318
- name: Compute cache keys
307319
id: cache-key
308320
run: |
309-
echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT
310-
echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT
311-
echo "json=$(cat cmake/json-version.txt)" >> $GITHUB_OUTPUT
321+
llvm_file="cmake/llvm-hash.txt"
322+
nvidia_file="cmake/nvidia-toolchain-version.json"
323+
json_file="cmake/json-version.txt"
324+
325+
# Check if files exist before proceeding
326+
if [[ ! -f "$llvm_file" || ! -f "$nvidia_file" || ! -f "$json_file" ]]; then
327+
echo "Error: Required dependency files are missing."
328+
exit 1
329+
fi
330+
331+
# Process the files if they exist
332+
echo "llvm=$(cat $llvm_file | cut -c 1-8)" >> $GITHUB_OUTPUT
333+
echo "nvidia=$(sha256sum $nvidia_file | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT
334+
echo "json=$(cat $json_file)" >> $GITHUB_OUTPUT
312335
echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT
313336
shell: bash
314337
- name: Cache build dependencies
@@ -441,9 +464,20 @@ jobs:
441464
- name: Compute cache keys
442465
id: cache-key
443466
run: |
444-
echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT
445-
echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT
446-
echo "json=$(cat cmake/json-version.txt)" >> $GITHUB_OUTPUT
467+
llvm_file="cmake/llvm-hash.txt"
468+
nvidia_file="cmake/nvidia-toolchain-version.json"
469+
json_file="cmake/json-version.txt"
470+
471+
# Check if files exist before proceeding
472+
if [[ ! -f "$llvm_file" || ! -f "$nvidia_file" || ! -f "$json_file" ]]; then
473+
echo "Error: Required dependency files are missing."
474+
exit 1
475+
fi
476+
477+
# Process the files if they exist
478+
echo "llvm=$(cat $llvm_file | cut -c 1-8)" >> $GITHUB_OUTPUT
479+
echo "nvidia=$(sha256sum $nvidia_file | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT
480+
echo "json=$(cat $json_file)" >> $GITHUB_OUTPUT
447481
echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT
448482
shell: bash
449483
- name: Cache build dependencies

.github/workflows/integration-tests.yml.in

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ jobs:
5858
with:
5959
files: |
6060
cmake/*.txt
61+
cmake/*.json
6162

6263
- name: Detect if enough time has passed since last post-submit run
6364
id: detect-time
@@ -140,7 +141,7 @@ jobs:
140141
- name: Compute hash of pre-commit config
141142
id: cache-key
142143
run: |
143-
echo "pre_commit_hash=$(sha256sum .pre-commit-config.yaml)" >> $GITHUB_OUTPUT
144+
echo "pre_commit_hash=$(sha256sum .pre-commit-config.yaml | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT
144145
shell: bash
145146

146147
- name: Cache pre-commit's cache dir
@@ -188,9 +189,20 @@ jobs:
188189
name: Compute cache keys
189190
id: cache-key
190191
run: |
191-
echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT
192-
echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT
193-
echo "json=$(cat cmake/json-version.txt)" >> $GITHUB_OUTPUT
192+
llvm_file="cmake/llvm-hash.txt"
193+
nvidia_file="cmake/nvidia-toolchain-version.json"
194+
json_file="cmake/json-version.txt"
195+
196+
# Check if files exist before proceeding
197+
if [[ ! -f "$llvm_file" || ! -f "$nvidia_file" || ! -f "$json_file" ]]; then
198+
echo "Error: Required dependency files are missing."
199+
exit 1
200+
fi
201+
202+
# Process the files if they exist
203+
echo "llvm=$(cat $llvm_file | cut -c 1-8)" >> $GITHUB_OUTPUT
204+
echo "nvidia=$(sha256sum $nvidia_file | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT
205+
echo "json=$(cat $json_file)" >> $GITHUB_OUTPUT
194206
echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT
195207
shell: bash
196208

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,11 +844,11 @@ def TT_PrintOp : TT_Op<"print", [SameVariadicOperandSize, MemoryEffects<[MemWrit
844844
def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
845845
let summary = "Device-side assert, as in CUDA for correctness checking";
846846
let description = [{
847-
`tt.assert` takes a condition tensor, a message string, a file string, a function string, and a line number.
847+
`tt.assert` takes a condition tensor and a message string.
848848
If the condition is false, the message is printed, and the program is aborted.
849849
}];
850-
let arguments = (ins TT_Tensor:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line);
851-
let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)";
850+
let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
851+
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
852852
}
853853

854854
//

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,30 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
3434
return failure();
3535
}
3636
}
37-
llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(),
38-
adaptor.getFunc(), adaptor.getLine(), rewriter);
37+
llAssert(op, condition, adaptor.getMessage(), rewriter);
3938
rewriter.eraseOp(op);
4039
return success();
4140
}
4241
// op: the op at which the assert is inserted. Unlike printf, we need to
4342
// know about the op to split the block.
4443
void llAssert(Operation *op, Value condition, StringRef message,
45-
StringRef file, StringRef func, int line,
4644
ConversionPatternRewriter &rewriter) const {
4745
ConversionPatternRewriter::InsertionGuard guard(rewriter);
46+
4847
auto ctx = rewriter.getContext();
4948
auto loc = op->getLoc();
49+
50+
StringRef file = "unknown";
51+
StringRef func = "unknown";
52+
int line = 0;
53+
int col = 0;
54+
55+
if (auto fileLineColLoc = dyn_cast<FileLineColLoc>(loc)) {
56+
file = fileLineColLoc.getFilename();
57+
line = fileLineColLoc.getLine();
58+
col = fileLineColLoc.getColumn();
59+
}
60+
5061
// #block1
5162
// if (condition) {
5263
// #block2

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -657,18 +657,25 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
657657
// Emit different versions of the induction variable. They will be
658658
// removed by dead code if not used.
659659

660-
// bounds_range = ub - lb
661-
// total_iterations = (bounds_range + step - 1) / step
660+
// range_diff = ub - lb
661+
// total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
662662
Type t = lb.getType();
663-
Value minus1 =
664-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
665-
Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
666-
Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
667-
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
668-
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
669-
670663
Value zero =
671664
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
665+
Value one =
666+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
667+
Value minusOne =
668+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
669+
Value stepLessZero = rewriter.create<arith::CmpIOp>(
670+
loc, arith::CmpIPredicate::slt, step, zero);
671+
Value stepDecr =
672+
rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
673+
674+
Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
675+
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
676+
Value rangeDecr =
677+
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
678+
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
672679

673680
// Capture predicates for dynamic loops.
674681
SmallVector<Value> predicates(maxStage + 1);
@@ -679,7 +686,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
679686
Value minusI =
680687
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
681688
Value iterI = rewriter.create<arith::AddIOp>(
682-
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
689+
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
683690
minusI);
684691
// newLastIter = lb + step * iterI
685692
Value newlastIter = rewriter.create<arith::AddIOp>(

python/src/ir.cc

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,17 +1562,10 @@ void init_triton_ir(py::module &&m) {
15621562
})
15631563
.def("create_assert",
15641564
[](TritonOpBuilder &self, Value &condition,
1565-
const std::string &message, const std::string &fileName,
1566-
const std::string &funcName, unsigned lineNo) -> void {
1565+
const std::string &message) -> void {
15671566
auto messageAttr = StringAttr::get(self.getBuilder().getContext(),
15681567
llvm::StringRef(message));
1569-
auto fileNameAttr = StringAttr::get(self.getBuilder().getContext(),
1570-
llvm::StringRef(fileName));
1571-
auto funcNameAttr = StringAttr::get(self.getBuilder().getContext(),
1572-
llvm::StringRef(funcName));
1573-
auto lineNoAttr = self.getBuilder().getI32IntegerAttr(lineNo);
1574-
self.create<AssertOp>(condition, messageAttr, fileNameAttr,
1575-
funcNameAttr, lineNoAttr);
1568+
self.create<AssertOp>(condition, messageAttr);
15761569
})
15771570
.def("create_assume",
15781571
[](TritonOpBuilder &self, Value &condition) {

python/test/unit/language/test_core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4977,6 +4977,7 @@ def kernel(Out):
49774977

49784978
intermediate_layouts = [
49794979
None,
4980+
SharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]),
49804981
SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
49814982
SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
49824983
SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),

python/triton/language/core.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2308,25 +2308,7 @@ def device_assert(cond, msg="", _builder=None):
23082308
:param msg: the message to print if the assertion fails. This is required to be a string literal.
23092309
'''
23102310
msg = _constexpr_to_value(msg)
2311-
import inspect
2312-
frame = inspect.currentframe()
2313-
module = inspect.getmodule(frame)
2314-
# The triton function module doesn't have the name attribute.
2315-
# We use this trick to find the caller.
2316-
while hasattr(module, "__name__"):
2317-
frame = frame.f_back
2318-
module = inspect.getmodule(frame)
2319-
lineno = 0
2320-
func_name = 'unknown'
2321-
file_name = 'unknown'
2322-
if frame is not None and frame.f_back is not None:
2323-
func_name = frame.f_code.co_name
2324-
file_name = frame.f_back.f_code.co_filename
2325-
# TODO: The line number currently indicates the line
2326-
# where the triton function is called but not where the
2327-
# device_assert is called. Need to enhance this.
2328-
lineno = frame.f_back.f_lineno
2329-
return semantic.device_assert(semantic.to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
2311+
return semantic.device_assert(semantic.to_tensor(cond, _builder), msg, _builder)
23302312

23312313

23322314
@builtin

python/triton/language/semantic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,12 +1636,12 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil
16361636
return tl.tensor(builder.create_print(prefix, hex, new_args, is_signed), tl.void)
16371637

16381638

1639-
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
1639+
def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
16401640
cond_ty = cond.type
16411641
if not cond_ty.is_block():
16421642
cond_ty = tl.block_type(cond_ty.scalar, (1, ))
16431643
cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
1644-
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)
1644+
return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)
16451645

16461646

16471647
def assume(cond, builder: ir.builder) -> tl.tensor:

python/triton/runtime/interpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,9 +643,9 @@ def create_print(self, prefix, hex, values, isSigned):
643643
if hex:
644644
np.set_printoptions(formatter=None)
645645

646-
def create_assert(self, condition, message, fileName, funcName, lineNo):
646+
def create_assert(self, condition, message):
647647
# Interpreter's device_assert function has a different format than Triton's device_assert
648-
assert condition, f"{message} in {fileName}:{funcName}:{lineNo}"
648+
assert condition, f"{message}"
649649

650650
def create_assume(self, condition):
651651
assert condition, "Assume failed"

0 commit comments

Comments
 (0)