Skip to content

Commit 2e65dc9

Browse files
timthlu (Tim Lu)
and authored
Add device_assert support for tensors (#315)
Adds device_assert support for up to 3D tensors. This is done by looping through every element of the tensor, extracting each element, and adding a `cf.assert` for each element. To make asserts work in the triton-shared CPU backend: compiling asserts results in symbols that are incompatible with shared libraries, so compilation with PIC enabled is required. Note: tested in both triton-shared and MAIA. Works in triton-shared, but requires a small change in MAIA in order to work. --------- Co-authored-by: Tim Lu <[email protected]>
1 parent 9ef5016 commit 2e65dc9

File tree

5 files changed

+176
-54
lines changed

5 files changed

+176
-54
lines changed

backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def _llir_to_bin(llir: str, metadata):
170170
subprocess.check_call(subprocess_args)
171171
else:
172172
llc_path = _get_llvm_bin_path("llc")
173-
subprocess.check_call([llc_path, src_path, "-filetype=obj", "-o", dst_path])
173+
subprocess.check_call([llc_path, src_path, "-filetype=obj", "-relocation-model=pic", "-o", dst_path])
174174

175175
return Path(dst_path).read_bytes()
176176

include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -790,25 +790,45 @@ struct AssertConverter : public OpConversionPattern<triton::AssertOp> {
790790
ConversionPatternRewriter &rewriter) const override {
791791
Value condVal = op.getCondition();
792792

793-
if (isa<mlir::TensorType>(condVal.getType())) {
794-
auto scalarVal = getScalarValue(op.getCondition(), op.getLoc(), rewriter);
795-
condVal = scalarVal ? scalarVal : condVal;
796-
}
797-
assert(condVal && isa<mlir::IntegerType>(condVal.getType()) &&
798-
"Only asserts on scalars are currently supported");
799-
800-
if (!condVal.getType().isInteger(1)) {
801-
auto zero =
802-
rewriter.create<mlir::arith::ConstantIntOp>(op.getLoc(), 0, 32);
803-
auto newCond = rewriter.create<mlir::arith::CmpIOp>(
804-
op.getLoc(), arith::CmpIPredicate::ne, condVal, zero);
805-
condVal = newCond.getResult();
806-
}
807-
808793
auto assertMessage =
809-
llvm::formatv("Assertion `{0}` failed", op.getMessage());
810-
rewriter.create<mlir::cf::AssertOp>(op.getLoc(), condVal,
811-
assertMessage.str());
794+
llvm::formatv("Assertion `{0}` failed", op.getMessage());
795+
796+
// The condition can only be I1 or I1Tensor (integer or tensor) from TritonOps.td.
797+
// Tensors will always be RankedTensorType.
798+
if (isa<mlir::IntegerType>(condVal.getType())) {
799+
// handle scalar case
800+
rewriter.create<mlir::cf::AssertOp>(op.getLoc(), condVal,
801+
assertMessage.str());
802+
} else if (auto tensorType = dyn_cast<RankedTensorType>(condVal.getType())) {
803+
// handle tensor case
804+
int64_t rank = tensorType.getRank();
805+
806+
// create identity mapping for access pattern
807+
SmallVector<AffineMap, 3> indexingMaps{AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext())};
808+
809+
// loops do not depend on each other
810+
SmallVector<utils::IteratorType, 3> iteratorTypes(rank, utils::IteratorType::parallel);
811+
812+
rewriter.create<linalg::GenericOp>(
813+
op.getLoc(),
814+
TypeRange{},
815+
condVal,
816+
ValueRange{},
817+
ArrayRef<AffineMap>{indexingMaps},
818+
ArrayRef<utils::IteratorType>{iteratorTypes},
819+
[&](OpBuilder &b, Location loc, ValueRange args) {
820+
// obtain the element in the tensor
821+
Value element = args[0];
822+
823+
// make a cf.assert for the current element
824+
b.create<mlir::cf::AssertOp>(loc, element, assertMessage.str());
825+
826+
b.create<linalg::YieldOp>(loc);
827+
});
828+
} else {
829+
op.emitError("Unexpected type in triton::AssertOp");
830+
return failure();
831+
}
812832

813833
rewriter.eraseOp(op);
814834
return success();
Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,50 @@
11
// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s
2-
tt.func public @assert_lol(%arg0: i32) {
3-
%c0_i32 = arith.constant 0 : i32
4-
%0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
5-
%1 = tt.splat %0 : i1 -> tensor<1xi1>
6-
tt.assert %1, "lol" : tensor<1xi1>
2+
3+
// CHECK: #map = affine_map<(d0) -> (d0)>
4+
// CHECK: #map1 = affine_map<(d0, d1) -> (d0, d1)>
5+
// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
6+
7+
tt.func public @assert_tensor_1d() {
8+
%0 = tensor.empty() : tensor<4xi1>
9+
tt.assert %0, "message" : tensor<4xi1>
710
tt.return
811
}
912

13+
// CHECK-LABEL: func.func @assert_tensor_1d
14+
// CHECK-NOT: tt.assert
15+
// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%0 : tensor<4xi1>) {
16+
// CHECK: ^bb0(%in: i1):
17+
// CHECK: cf.assert %in, "Assertion `message` failed"
18+
// CHECK: linalg.yield
19+
// CHECK: }
20+
// CHECK-NOT: tt.assert
21+
22+
tt.func public @assert_tensor_2d() {
23+
%0 = tensor.empty() : tensor<4x4xi1>
24+
tt.assert %0, "message" : tensor<4x4xi1>
25+
tt.return
26+
}
27+
28+
// CHECK-LABEL: func.func @assert_tensor_2d
29+
// CHECK-NOT: tt.assert
30+
// CHECK: linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<4x4xi1>) {
31+
// CHECK: ^bb0(%in: i1):
32+
// CHECK: cf.assert %in, "Assertion `message` failed"
33+
// CHECK: linalg.yield
34+
// CHECK: }
35+
// CHECK-NOT: tt.assert
36+
37+
tt.func public @assert_tensor_3d() {
38+
%0 = tensor.empty() : tensor<4x4x4xi1>
39+
tt.assert %0, "message" : tensor<4x4x4xi1>
40+
tt.return
41+
}
1042

11-
// CHECK-LABEL: func.func @assert_lol
12-
// CHECK-SAME: ([[PARAM_0_:%.+]]: i32, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
13-
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : i32
14-
// CHECK: [[VAR_0_:%.+]] = arith.cmpi sgt, [[PARAM_0_]], [[CST_0_]] : i32
15-
// CHECK: cf.assert [[VAR_0_]], "Assertion `lol` failed"
16-
// CHECK: return
17-
// CHECK: }
43+
// CHECK-LABEL: func.func @assert_tensor_3d
44+
// CHECK-NOT: tt.assert
45+
// CHECK: linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<4x4x4xi1>) {
46+
// CHECK: ^bb0(%in: i1):
47+
// CHECK: cf.assert %in, "Assertion `message` failed"
48+
// CHECK: linalg.yield
49+
// CHECK: }
50+
// CHECK-NOT: tt.assert
Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,50 @@
11
// RUN: triton-shared-opt --triton-arith-to-linalg %s | FileCheck %s
2-
tt.func public @assert_lol(%arg0: i32) {
3-
%c0_i32 = arith.constant 0 : i32
4-
%0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
5-
%1 = tt.splat %0 : i1 -> tensor<1xi1>
6-
tt.assert %1, "lol" : tensor<1xi1>
2+
3+
// CHECK: #map = affine_map<(d0) -> (d0)>
4+
// CHECK: #map1 = affine_map<(d0, d1) -> (d0, d1)>
5+
// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
6+
7+
tt.func public @assert_tensor_1d() {
8+
%0 = tensor.empty() : tensor<4xi1>
9+
tt.assert %0, "message" : tensor<4xi1>
10+
tt.return
11+
}
12+
13+
// CHECK-LABEL: func.func @assert_tensor_1d
14+
// CHECK-NOT: tt.assert
15+
// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%0 : tensor<4xi1>) {
16+
// CHECK: ^bb0(%in: i1):
17+
// CHECK: cf.assert %in, "Assertion `message` failed"
18+
// CHECK: linalg.yield
19+
// CHECK: }
20+
// CHECK-NOT: tt.assert
21+
22+
tt.func public @assert_tensor_2d() {
23+
%0 = tensor.empty() : tensor<4x4xi1>
24+
tt.assert %0, "message" : tensor<4x4xi1>
25+
tt.return
26+
}
27+
28+
// CHECK-LABEL: func.func @assert_tensor_2d
29+
// CHECK-NOT: tt.assert
30+
// CHECK: linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<4x4xi1>) {
31+
// CHECK: ^bb0(%in: i1):
32+
// CHECK: cf.assert %in, "Assertion `message` failed"
33+
// CHECK: linalg.yield
34+
// CHECK: }
35+
// CHECK-NOT: tt.assert
36+
37+
tt.func public @assert_tensor_3d() {
38+
%0 = tensor.empty() : tensor<4x4x4xi1>
39+
tt.assert %0, "message" : tensor<4x4x4xi1>
740
tt.return
841
}
942

10-
// CHECK-LABEL: func.func @assert_lol
11-
// CHECK-SAME: ([[PARAM_0_:%.+]]: i32, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
12-
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : i32
13-
// CHECK-DAG: [[VAR_0_:%.+]] = arith.cmpi sgt, [[PARAM_0_]], [[CST_0_]] : i32
14-
// CHECK: cf.assert [[VAR_0_]], "Assertion `lol` failed"
15-
// CHECK: return
16-
// CHECK: }
43+
// CHECK-LABEL: func.func @assert_tensor_3d
44+
// CHECK-NOT: tt.assert
45+
// CHECK: linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<4x4x4xi1>) {
46+
// CHECK: ^bb0(%in: i1):
47+
// CHECK: cf.assert %in, "Assertion `message` failed"
48+
// CHECK: linalg.yield
49+
// CHECK: }
50+
// CHECK-NOT: tt.assert
Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,50 @@
11
// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
2-
tt.func public @assert_lol(%arg0: i32) {
3-
%c0_i32 = arith.constant 0 : i32
4-
%0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
5-
%1 = tt.splat %0 : i1 -> tensor<1xi1>
6-
tt.assert %1, "lol": tensor<1xi1>
2+
3+
// CHECK: #map = affine_map<(d0) -> (d0)>
4+
// CHECK: #map1 = affine_map<(d0, d1) -> (d0, d1)>
5+
// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
6+
7+
tt.func public @assert_tensor_1d() {
8+
%0 = tensor.empty() : tensor<4xi1>
9+
tt.assert %0, "message" : tensor<4xi1>
10+
tt.return
11+
}
12+
13+
// CHECK-LABEL: func.func @assert_tensor_1d
14+
// CHECK-NOT: tt.assert
15+
// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%0 : tensor<4xi1>) {
16+
// CHECK: ^bb0(%in: i1):
17+
// CHECK: cf.assert %in, "Assertion `message` failed"
18+
// CHECK: linalg.yield
19+
// CHECK: }
20+
// CHECK-NOT: tt.assert
21+
22+
tt.func public @assert_tensor_2d() {
23+
%0 = tensor.empty() : tensor<4x4xi1>
24+
tt.assert %0, "message" : tensor<4x4xi1>
25+
tt.return
26+
}
27+
28+
// CHECK-LABEL: func.func @assert_tensor_2d
29+
// CHECK-NOT: tt.assert
30+
// CHECK: linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<4x4xi1>) {
31+
// CHECK: ^bb0(%in: i1):
32+
// CHECK: cf.assert %in, "Assertion `message` failed"
33+
// CHECK: linalg.yield
34+
// CHECK: }
35+
// CHECK-NOT: tt.assert
36+
37+
tt.func public @assert_tensor_3d() {
38+
%0 = tensor.empty() : tensor<4x4x4xi1>
39+
tt.assert %0, "message" : tensor<4x4x4xi1>
740
tt.return
841
}
942

10-
// CHECK: func.func @assert_lol(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
11-
// CHECK: %c0_i32 = arith.constant 0 : i32
12-
// CHECK: %0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
13-
// CHECK: cf.assert %0, "Assertion `lol` failed"
14-
// CHECK: return
15-
// CHECK: }
43+
// CHECK-LABEL: func.func @assert_tensor_3d
44+
// CHECK-NOT: tt.assert
45+
// CHECK: linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<4x4x4xi1>) {
46+
// CHECK: ^bb0(%in: i1):
47+
// CHECK: cf.assert %in, "Assertion `message` failed"
48+
// CHECK: linalg.yield
49+
// CHECK: }
50+
// CHECK-NOT: tt.assert

0 commit comments

Comments
 (0)