Skip to content

Commit b059f2c

Browse files
committed
Address code review comments
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 7f33969 commit b059f2c

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: triton-opt %s -split-input-file -triton-intel-remove-boundary-checks | FileCheck %s
2+
3+
module {
4+
tt.func public @simple_load(%load_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %store_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
5+
%c1_i64 = arith.constant 1 : i64
6+
%c64_i64 = arith.constant 64 : i64
7+
%c512_i64 = arith.constant 512 : i64
8+
%c1024_i64 = arith.constant 1024 : i64
9+
%c0_i32 = arith.constant 0 : i32
10+
%x = arith.constant 10 : i32
11+
%in = tt.make_tensor_ptr %load_ptr, [%c1_i64, %c64_i64, %c1024_i64], [%c512_i64, %c64_i64, %c1_i64], [%c0_i32, %c0_i32, %x] {order = array<i32: 2, 1, 0>} : <tensor<1x64x64xf16>>
12+
%load = tt.load %in {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x64x64xf16>>
13+
tt.return
14+
}
15+
// CHECK-LABEL: simple_load
16+
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr
17+
// CHECK: tt.load [[PTR]] : !tt.ptr<tensor<1x64x64xf16>>
18+
}
19+
20+
// -----
21+
22+
module {
23+
tt.func public @load_in_for_loop(%load_ptr0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %load_ptr1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %store_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
24+
%c0_i32 = arith.constant 0 : i32
25+
%c1_i32 = arith.constant 1 : i32
26+
%c20_i32 = arith.constant 20 : i32
27+
%c64_i32 = arith.constant 64 : i32
28+
%c1024_i32 = arith.constant 1024 : i32
29+
scf.for %x = %c0_i32 to %c20_i32 step %c1_i32 : i32 {
30+
%pid = tt.get_program_id x : i32
31+
%c0_i64 = arith.constant 0 : i64
32+
%c1_i64 = arith.constant 1 : i64
33+
%c512_i64 = arith.constant 512 : i64
34+
%c1024_i64 = arith.constant 1024 : i64
35+
%c64_i64 = arith.constant 64 : i64
36+
%c65536_i64 = arith.constant 65536 : i64
37+
%ptr0 = tt.make_tensor_ptr %load_ptr0, [%c512_i64, %c1024_i64, %c64_i64], [%c65536_i64, %c64_i64, %c1_i64], [%x, %pid, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
38+
%load0 = tt.load %ptr0 {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x512x64xf16>>
39+
%9 = arith.bitcast %c0_i32 : i32 to i32
40+
%10 = arith.bitcast %c1024_i32 : i32 to i32
41+
%11 = arith.bitcast %c64_i32 : i32 to i32
42+
scf.for %z = %9 to %10 step %11 iter_args() -> () : i32 {
43+
%ptr1 = tt.make_tensor_ptr %load_ptr1, [%c512_i64, %c64_i64, %c1024_i64], [%c65536_i64, %c1_i64, %c64_i64], [%x, %c0_i32, %z] {order = array<i32: 2, 0, 1>} : <tensor<1x64x64xf16>>
44+
// a. boundaryCheck = 1 checks the block ptr offset at index 2 (%z)
45+
// b. boundaryCheck = 2 checks the block ptr offset at index 1 (%y)
46+
// Check (a) is unnecessary because max(%z) = 920 which is less than %s2 (1024)
47+
// Check (a) is trivially unnecessary because %y(zero) < %s1(64)
48+
%load1 = tt.load %ptr1 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x64x64xf16>>
49+
}
50+
}
51+
tt.return
52+
}
53+
// CHECK-LABEL: load_in_for_loop
54+
// CHECK-COUNT-2: scf.for
55+
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr
56+
// CHECK: tt.load [[PTR]] : !tt.ptr<tensor<1x64x64xf16>>
57+
}

third_party/intel/include/Dialect/Triton/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def TritonIntelRemoveBoundaryChecks
8080
%lb = arith.bitcast %c0_i32 : i32 to i32
8181
%ub = arith.bitcast %c1024_i32 : i32 to i32
8282
%st = arith.bitcast %c64_i32 : i32 to i32
83-
scf.for %i = %lb to %ub step %st : i32 {
83+
scf.for %iv = %lb to %ub step %st : i32 {
8484
%s0 = arith.constant 512 : i64
8585
%s1 = arith.constant 64 : i64
8686
%s2 = arith.constant 1024 : i64
@@ -98,7 +98,7 @@ def TritonIntelRemoveBoundaryChecks
9898
The transformation would drop the boundary check on the load operation because:
9999
- `%ptr` is never advanced in the loop
100100
- `%iv` has values [0, 64, 128, ..., 960], max(%iv) = 960
101-
- `%s2` is qual to 1014
101+
- `%s2` is equal to 1014
102102
- the boundary check expression `%iv` < `%s2` is always true
103103
}];
104104

third_party/intel/lib/Dialect/Triton/Transforms/RemoveBoundaryChecks.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
#include "intel/include/Dialect/Triton/Transforms/Passes.h"
32
#include "intel/include/Utils/Utility.h"
43
#include "mlir/Dialect/Arith/IR/Arith.h"

0 commit comments

Comments
 (0)