
Commit 686a8c1

[triton-raise-block-ptr]: Fix failing test (#3219)
Fixes #3218

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent aeba507 commit 686a8c1

6 files changed: 115 additions, 78 deletions


test/Triton/Intel/RaiseToBlockPointers/addptr_cmpge.mlir

Lines changed: 8 additions & 3 deletions
@@ -1,5 +1,4 @@
 // RUN: triton-opt %s -triton-raise-block-pointer --split-input-file -canonicalize | FileCheck %s
-// XFAIL: *
 
 // These tests check that loads/stores that exhibit a cmp ge against 0 work
 // correctly with the pointer analysis pass
@@ -45,7 +44,10 @@ tt.func public @test_masked_load(%arg0: !tt.ptr<f16>) -> tensor<16x16xf16> {
 }
 
 // CHECK: tt.func public @test_masked_load([[arg0:%.+]]: !tt.ptr<f16>) -> tensor<16x16xf16> {
-// CHECK: [[VAR_0:%.+]] = tt.make_tensor_ptr [[arg0]], {{.*}} {order = array<i32>} : <tensor<16x16xf16>>
+// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
+// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
+// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
+// CHECK: [[VAR_0:%.+]] = tt.make_tensor_ptr [[arg0]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_0_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<16x16xf16>>
 // CHECK: [[VAR_1:%.+]] = tt.load [[VAR_0]] evictionPolicy = evict_last : !tt.ptr<tensor<16x16xf16>>
 // CHECK: tt.return [[VAR_1]] : tensor<16x16xf16>
 // CHECK: }
@@ -71,6 +73,9 @@ tt.func public @test_masked_store(%arg0: !tt.ptr<f16>) {
 
 // CHECK: tt.func public @test_masked_store([[arg0:%.+]]: !tt.ptr<f16>) {
 // CHECK-DAG: [[VAR_cst:%.+]] = arith.constant dense<1.500000e+01> : tensor<16x16xf16>
-// CHECK-DAG: [[VAR_0:%.+]] = tt.make_tensor_ptr [[arg0]], {{.*}} {order = array<i32>} : <tensor<16x16xf16>>
+// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
+// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
+// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
+// CHECK: [[VAR_0:%.+]] = tt.make_tensor_ptr [[arg0]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_0_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<16x16xf16>>
 // CHECK: tt.store [[VAR_0]], [[VAR_cst]] : !tt.ptr<tensor<16x16xf16>>
 // CHECK: }
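
For reference, the block-pointer form these updated CHECK lines describe looks roughly as follows. This is a hand-written sketch assembled from the patterns above, not literal pass output; the SSA names (%c0_i64, %ptr, %val) are illustrative.

// Sketch of the raised masked-load test body, per the new CHECK lines above.
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
// Shape, stride, and offset operands of tt.make_tensor_ptr are now matched as
// explicit constants instead of the old {{.*}} wildcard.
%ptr = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c1_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32>} : <tensor<16x16xf16>>
%val = tt.load %ptr evictionPolicy = evict_last : !tt.ptr<tensor<16x16xf16>>
tt.return %val : tensor<16x16xf16>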

test/Triton/Intel/RaiseToBlockPointers/addptr_for_expand_ptr.mlir

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
-// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s
+// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
 // XFAIL: *
+// TODO: add support for tt.expand_dims in loops
 
 module {
 tt.func @kernel(

test/Triton/Intel/RaiseToBlockPointers/kernel-02-fused-softmax.mlir

Lines changed: 18 additions & 28 deletions
@@ -1,5 +1,4 @@
 // RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
-// XFAIL: *
 
 module {
 tt.func public @softmax_kernel_012345(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32) {
@@ -15,7 +14,7 @@ module {
 %8 = tt.splat %cst : f32 -> tensor<128xf32>
 // TODO: add back once masked loads are supported
 // %9 = tt.load %5, %7, %8 : tensor<128x!tt.ptr<f32>>
-%9 = tt.load %5, %7 : tensor<128x!tt.ptr<f32>>
+%9 = tt.load %5 : tensor<128x!tt.ptr<f32>>
 %10 = "tt.reduce"(%9) ({
 ^bb0(%arg5: f32, %arg6: f32):
 %21 = arith.cmpf ogt, %arg5, %arg6 : f32
@@ -48,35 +47,26 @@ module {
 // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
 // CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32
 // CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32
-
-// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index
-// CHECK-DAG: [[VAR_3_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, shape: [0], order: [] : <f32> to tensor<128x!tt.ptr<f32>>
-// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
-// CHECK: [[VAR_5_:%.+]] = arith.minsi [[VAR_4_]], [[CST_128_]] : index
-// CHECK: [[VAR_6_:%.+]] = arith.maxsi [[VAR_5_]], [[CST_0_1_]] : index
-// CHECK: [[VAR_7_:%.+]] = "tts.load"([[VAR_3_]], [[VAR_6_]], [[CST_0_]]) <{operandSegmentSizes = array<i32: 1, 1, 1>, static_mask_dims = array<i64: -9223372036854775808>}> : (tensor<128x!tt.ptr<f32>>, index, f32) -> tensor<128xf32>
-// CHECK: [[VAR_8_:%.+]] = "tt.reduce"([[VAR_7_]]) <{axis = 0 : i32}> ({
+// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_1_]]] {{.*}} : <tensor<128xf32>>
+// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr<tensor<128xf32>>
+// CHECK: [[VAR_4_:%.+]] = "tt.reduce"([[VAR_3_]]) <{axis = 0 : i32}> ({
 // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_21_:%.+]] = arith.cmpf ogt, [[IN_0_]], [[IN_1_]] : f32
-// CHECK: [[VAR_22_:%.+]] = arith.select [[VAR_21_]], [[IN_0_]], [[IN_1_]] : f32
-// CHECK: tt.reduce.return [[VAR_22_]] : f32
+// CHECK: [[VAR_13_:%.+]] = arith.cmpf ogt, [[IN_0_]], [[IN_1_]] : f32
+// CHECK: [[VAR_14_:%.+]] = arith.select [[VAR_13_]], [[IN_0_]], [[IN_1_]] : f32
+// CHECK: tt.reduce.return [[VAR_14_]] : f32
 // CHECK: }) : (tensor<128xf32>) -> f32
-// CHECK: [[VAR_9_:%.+]] = tt.splat [[VAR_8_]] : f32 -> tensor<128xf32>
-// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_7_]], [[VAR_9_]] : tensor<128xf32>
-// CHECK: [[VAR_11_:%.+]] = math.exp [[VAR_10_]] : tensor<128xf32>
-// CHECK: [[VAR_12_:%.+]] = "tt.reduce"([[VAR_11_]]) <{axis = 0 : i32}> ({
+// CHECK: [[VAR_5_:%.+]] = tt.splat [[VAR_4_]] : f32 -> tensor<128xf32>
+// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_3_]], [[VAR_5_]] : tensor<128xf32>
+// CHECK: [[VAR_7_:%.+]] = math.exp [[VAR_6_]] : tensor<128xf32>
+// CHECK: [[VAR_8_:%.+]] = "tt.reduce"([[VAR_7_]]) <{axis = 0 : i32}> ({
 // CHECK: ^bb0([[IN_2_:%.+]]: f32, [[IN_3_:%.+]]: f32):
-// CHECK: [[VAR_21_1_:%.+]] = arith.addf [[IN_2_]], [[IN_3_]] : f32
-// CHECK: tt.reduce.return [[VAR_21_1_]] : f32
+// CHECK: [[VAR_13_1_:%.+]] = arith.addf [[IN_2_]], [[IN_3_]] : f32
+// CHECK: tt.reduce.return [[VAR_13_1_]] : f32
 // CHECK: }) : (tensor<128xf32>) -> f32
-// CHECK: [[VAR_13_:%.+]] = tt.splat [[VAR_12_]] : f32 -> tensor<128xf32>
-// CHECK-DAG: [[VAR_14_:%.+]] = arith.divf [[VAR_11_]], [[VAR_13_]] : tensor<128xf32>
-// CHECK-DAG: [[VAR_15_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32
-// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index
-// CHECK-DAG: [[VAR_17_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128], strides: [1], offsets: {{.}}[[VAR_16_]]{{.}}, shape: [0], order: [] : <f32> to tensor<128x!tt.ptr<f32>>
-// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
-// CHECK: [[VAR_19_:%.+]] = arith.minsi [[VAR_18_]], [[CST_128_]] : index
-// CHECK: [[VAR_20_:%.+]] = arith.maxsi [[VAR_19_]], [[CST_0_1_]] : index
-// CHECK: "tts.store"([[VAR_17_]], [[VAR_14_]], [[VAR_20_]]) <{static_mask_dims = array<i64: -9223372036854775808>}> : (tensor<128x!tt.ptr<f32>>, tensor<128xf32>, index) -> ()
+// CHECK: [[VAR_9_:%.+]] = tt.splat [[VAR_8_]] : f32 -> tensor<128xf32>
+// CHECK-DAG: [[VAR_10_:%.+]] = arith.divf [[VAR_7_]], [[VAR_9_]] : tensor<128xf32>
+// CHECK-DAG: [[VAR_11_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32
+// CHECK: [[VAR_12_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_11_]]] {{.*}} : <tensor<128xf32>>
+// CHECK: tt.store [[VAR_12_]], [[VAR_10_]] : !tt.ptr<tensor<128xf32>>
 // CHECK: tt.return
 // CHECK: }
Lines changed: 11 additions & 15 deletions
@@ -1,5 +1,4 @@
 // RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
-// XFAIL: *
 
 // IR from python/examples/sign_extend.py
 module {
@@ -16,7 +15,9 @@ module {
 %8 = arith.cmpi slt, %5, %7 : tensor<4xi64>
 %9 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
 %10 = tt.addptr %9, %5 : tensor<4x!tt.ptr<f32>>, tensor<4xi64>
-%11 = tt.load %10, %8, %cst : tensor<4x!tt.ptr<f32>>
+%11 = tt.load %10 : tensor<4x!tt.ptr<f32>>
+// TODO: uncomment once masked loads are supported
+// %11 = tt.load %10, %8, %cst : tensor<4x!tt.ptr<f32>>
 %12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
 %13 = tt.addptr %12, %2 : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
 tt.store %13, %11 : tensor<4x!tt.ptr<f32>>
@@ -25,18 +26,13 @@ module {
 }
 
 // CHECK: tt.func public @sign_extend([[PARAM_0_:%.+]]: !tt.ptr<i32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: !tt.ptr<f32>, [[PARAM_3_:%.+]]: i32) attributes {noinline = false} {
-// CHECK-DAG: [[CST_1_dot_100000_:%.+]] = arith.constant 1.100000e+01 : f32
-// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
-// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = tt.load [[PARAM_0_]] : !tt.ptr<i32>
-// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[LOAD_PARAM_0_MEM_]] : i32 to index
-// CHECK-DAG: [[VAR_2_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4], strides: [1], offsets: {{.}}[[VAR_1_]]{{.}}, shape: [0], order: [] : <f32> to tensor<4x!tt.ptr<f32>>
-// CHECK-DAG: [[VAR_3_:%.+]] = arith.addi [[VAR_1_]], [[CST_4_]] : index
-// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
-// CHECK: [[VAR_5_:%.+]] = arith.minsi [[VAR_3_]], [[VAR_4_]] : index
-// CHECK: [[VAR_6_:%.+]] = arith.maxsi [[VAR_5_]], [[VAR_1_]] : index
-// CHECK: [[VAR_7_:%.+]] = arith.subi [[VAR_6_]], [[VAR_1_]] : index
-// CHECK-DAG: [[VAR_8_:%.+]] = "tts.load"([[VAR_2_]], [[VAR_7_]], [[CST_1_dot_100000_]]) <{operandSegmentSizes = array<i32: 1, 1, 1>, static_mask_dims = array<i64: -9223372036854775808>}> : (tensor<4x!tt.ptr<f32>>, index, f32) -> tensor<4xf32>
-// CHECK-DAG: [[VAR_9_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [4], strides: [1], offsets: [0], shape: [0], order: [] : <f32> to tensor<4x!tt.ptr<f32>>
-// CHECK: "tts.store"([[VAR_9_]], [[VAR_8_]]) <{static_mask_dims = array<i64>}> : (tensor<4x!tt.ptr<f32>>, tensor<4xf32>) -> ()
+// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
+// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
+// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[VAR_0_:%.+]] = tt.load [[PARAM_0_]] : !tt.ptr<i32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_0_]]] {{.*}} : <tensor<4xf32>>
+// CHECK-DAG: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr<tensor<4xf32>>
+// CHECK: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_0_i32]]] {{.*}} : <tensor<4xf32>>
+// CHECK: tt.store [[VAR_3_]], [[VAR_2_]] : !tt.ptr<tensor<4xf32>>
 // CHECK: tt.return
 // CHECK: }

test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir

Lines changed: 25 additions & 20 deletions
@@ -1,5 +1,5 @@
 // RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
-// XFAIL: *
+
 
 // We currently do not support this kind of modulo pattern:
 // (a + arrange(0, K)) % M
@@ -59,15 +59,15 @@ module {
 }
 
 // CHECK: tt.func public @wrap_side_by_side_masked_loop_01234567([[arg0_:.+]]: !tt.ptr<f32>, [[arg1_:.+]]: !tt.ptr<f32>, [[arg2_:.+]]: i32, [[arg3_:.+]]: i32, [[arg4_:.+]]: i32, [[arg5_:.+]]: i32, [[arg6_:.+]]: i32, [[arg7_:.+]]: i32) {
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
 // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-9.900000e+01> : tensor<4x4xf32>
-// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
-// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32
-// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
+// CHECK-DAG: [[CST_1_i32:%.+]] = arith.constant 1 : i32
+// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32
 // CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<2> : tensor<4x1xi32>
 // CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<6> : tensor<4xi32>
 // CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2> : tensor<4xi32>
-// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32
+// CHECK-DAG: [[CST_4_i32:%.+]] = arith.constant 4 : i32
 // CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_1_:%.+]] = arith.addi [[VAR_0_]], [[VAR_cst_2_]] : tensor<4xi32>
@@ -90,22 +90,27 @@ module {
 // CHECK-DAG: [[VAR_15_:%.+]] = tt.addptr [[VAR_14_]], [[VAR_13_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
 // CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
 // CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[arg6_]] : i32 to index
-// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[arg7_]] : i32 to index
-// CHECK: [[VAR_19_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
-// CHECK-DAG: [[VAR_20_:%.+]] = tt.broadcast [[VAR_19_]] : tensor<4x1xi1> -> tensor<4x4xi1>
-// CHECK-DAG: [[VAR_21_:%.+]] = arith.muli [[arg4_]], [[CST_4_]] : i32
+// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[VAR_17_]] : index to i64
+// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[arg7_]] : i32 to index
+// CHECK-DAG: [[VAR_20_:%.+]] = arith.index_cast [[VAR_19_]] : index to i64
+// CHECK-DAG: [[VAR_21_:%.+]] = arith.trunci [[VAR_18_]] : i64 to i32
+// CHECK-DAG: [[VAR_22_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_21_]] : i32
+// CHECK-DAG: [[VAR_23_:%.+]] = arith.trunci [[VAR_20_]] : i64 to i32
+// CHECK-DAG: [[VAR_24_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_23_]] : i32
+// CHECK: [[VAR_25_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_18_]], [[VAR_20_]]], {{\[}}[[VAR_22_]], [[VAR_24_]]] {{.*}} : <tensor<4x4xf32>>
+// CHECK: [[VAR_26_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
+// CHECK-DAG: [[VAR_27_:%.+]] = tt.broadcast [[VAR_26_]] : tensor<4x1xi1> -> tensor<4x4xi1>
+// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_22_:%.+]] = tt.splat [[VAR_21_]] : i32 -> tensor<4x4xi32>
-// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[arg5_]], [[CST_4_]] : i32
+// CHECK-DAG: [[VAR_29_:%.+]] = tt.splat [[VAR_28_]] : i32 -> tensor<4x4xi32>
+// CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_24_:%.+]] = arith.index_cast [[VAR_23_]] : i32 to index
-// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[CST_0_]]) -> (tensor<4x4x!tt.ptr<f32>>, index) : i32 {
-// CHECK-DAG: [[VAR_26_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4, 4], strides: {{.}}[[VAR_17_]], [[VAR_18_]]{{.}}, offsets: {{.}}[[VAR_arg10_]], [[CST_0_]]{{.}}, shape: [0, 0], order: [] : <f32> to tensor<4x4x!tt.ptr<f32>>
-// CHECK-DAG: [[LOAD_VAR_arg9_MEM_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_20_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr<f32>>
-// CHECK: "tts.store"([[VAR_26_]], [[LOAD_VAR_arg9_MEM_]]) <{static_mask_dims = array<i64>}> : (tensor<4x4x!tt.ptr<f32>>, tensor<4x4xf32>) -> ()
-// CHECK-DAG: [[VAR_28_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_22_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
-// CHECK-DAG: [[VAR_29_:%.+]] = arith.addi [[VAR_arg10_]], [[VAR_24_]] : index
-// CHECK: scf.yield [[VAR_28_]], [[VAR_29_]] : tensor<4x4x!tt.ptr<f32>>, index
+// CHECK-DAG: [[VAR_31_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] to [[CST_2_i32]] step [[CST_1_i32]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_25_]]) -> (tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>)
+// CHECK: [[VAR_32_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_27_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr<f32>>
+// CHECK: tt.store [[VAR_arg10_]], [[VAR_32_]] : !tt.ptr<tensor<4x4xf32>>
+// CHECK-DAG: [[VAR_33_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_29_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+// CHECK-DAG: [[VAR_34_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : <tensor<4x4xf32>>
+// CHECK: scf.yield [[VAR_33_]], [[VAR_34_]] : tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>
 // CHECK: }
 // CHECK: tt.return
 // CHECK: }
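
Taken together, the updated CHECK lines describe a loop in which the store side is raised to a block pointer created once before the scf.for, carried as an iter_arg and stepped with tt.advance, while the wraparound load side remains a tensor of pointers. A rough sketch of that shape follows; all SSA names (%dst0, %src0, %mask, %other, %row_step, %col_step, the strides and offsets) are illustrative placeholders, not values from the test.

// Sketch of the loop structure the new CHECK lines expect (illustrative names).
%dst0 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%stride0, %stride1], [%off0, %off1] {order = array<i32>} : <tensor<4x4xf32>>
%res:2 = scf.for %i = %c0_i32 to %c2_i32 step %c1_i32
    iter_args(%src = %src0, %dst = %dst0) -> (tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>) : i32 {
  // The wraparound load still goes through the masked tensor-of-pointers form.
  %v = tt.load %src, %mask, %other : tensor<4x4x!tt.ptr<f32>>
  // The store goes through the block pointer carried as an iter_arg.
  tt.store %dst, %v : !tt.ptr<tensor<4x4xf32>>
  %src_next = tt.addptr %src, %row_step : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
  %dst_next = tt.advance %dst, [%c0_i32, %col_step] : <tensor<4x4xf32>>
  scf.yield %src_next, %dst_next : tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>
}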
