11// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
2- // XFAIL: *
2+
33
44// We currently do not support this kind of modulo pattern:
55// (a + arange(0, K)) % M
@@ -59,15 +59,15 @@ module {
5959}
6060
6161// CHECK: tt.func public @wrap_side_by_side_masked_loop_01234567([[arg0_:.+]]: !tt.ptr<f32>, [[arg1_:.+]]: !tt.ptr<f32>, [[arg2_:.+]]: i32, [[arg3_:.+]]: i32, [[arg4_:.+]]: i32, [[arg5_:.+]]: i32, [[arg6_:.+]]: i32, [[arg7_:.+]]: i32) {
62- // CHECK-DAG: [[CST_0_ :%.+]] = arith.constant 0 : index
62+ // CHECK-DAG: [[CST_0_i64 :%.+]] = arith.constant 0 : i64
6363// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-9.900000e+01> : tensor<4x4xf32>
64- // CHECK-DAG: [[CST_1_ :%.+]] = arith.constant 1 : i32
65- // CHECK-DAG: [[CST_0_1_ :%.+]] = arith.constant 0 : i32
66- // CHECK-DAG: [[CST_2_ :%.+]] = arith.constant 2 : i32
64+ // CHECK-DAG: [[CST_1_i32 :%.+]] = arith.constant 1 : i32
65+ // CHECK-DAG: [[CST_0_i32 :%.+]] = arith.constant 0 : i32
66+ // CHECK-DAG: [[CST_2_i32 :%.+]] = arith.constant 2 : i32
6767// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<2> : tensor<4x1xi32>
6868// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<6> : tensor<4xi32>
6969// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2> : tensor<4xi32>
70- // CHECK-DAG: [[CST_4_ :%.+]] = arith.constant 4 : i32
70+ // CHECK-DAG: [[CST_4_i32 :%.+]] = arith.constant 4 : i32
7171// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
7272// CHECK-NOT: separator of consecutive DAGs
7373// CHECK-DAG: [[VAR_1_:%.+]] = arith.addi [[VAR_0_]], [[VAR_cst_2_]] : tensor<4xi32>
@@ -90,22 +90,27 @@ module {
9090// CHECK-DAG: [[VAR_15_:%.+]] = tt.addptr [[VAR_14_]], [[VAR_13_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
9191// CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
9292// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[arg6_]] : i32 to index
93- // CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[arg7_]] : i32 to index
94- // CHECK: [[VAR_19_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
95- // CHECK-DAG: [[VAR_20_:%.+]] = tt.broadcast [[VAR_19_]] : tensor<4x1xi1> -> tensor<4x4xi1>
96- // CHECK-DAG: [[VAR_21_:%.+]] = arith.muli [[arg4_]], [[CST_4_]] : i32
93+ // CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[VAR_17_]] : index to i64
94+ // CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[arg7_]] : i32 to index
95+ // CHECK-DAG: [[VAR_20_:%.+]] = arith.index_cast [[VAR_19_]] : index to i64
96+ // CHECK-DAG: [[VAR_21_:%.+]] = arith.trunci [[VAR_18_]] : i64 to i32
97+ // CHECK-DAG: [[VAR_22_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_21_]] : i32
98+ // CHECK-DAG: [[VAR_23_:%.+]] = arith.trunci [[VAR_20_]] : i64 to i32
99+ // CHECK-DAG: [[VAR_24_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_23_]] : i32
100+ // CHECK: [[VAR_25_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_18_]], [[VAR_20_]]], {{\[}}[[VAR_22_]], [[VAR_24_]]] {{.*}} : <tensor<4x4xf32>>
101+ // CHECK: [[VAR_26_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
102+ // CHECK-DAG: [[VAR_27_:%.+]] = tt.broadcast [[VAR_26_]] : tensor<4x1xi1> -> tensor<4x4xi1>
103+ // CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32
97104// CHECK-NOT: separator of consecutive DAGs
98- // CHECK-DAG: [[VAR_22_ :%.+]] = tt.splat [[VAR_21_ ]] : i32 -> tensor<4x4xi32>
99- // CHECK-DAG: [[VAR_23_ :%.+]] = arith.muli [[arg5_]], [[CST_4_ ]] : i32
105+ // CHECK-DAG: [[VAR_29_ :%.+]] = tt.splat [[VAR_28_ ]] : i32 -> tensor<4x4xi32>
106+ // CHECK-DAG: [[VAR_30_ :%.+]] = arith.muli [[arg5_]], [[CST_4_i32 ]] : i32
100107// CHECK-NOT: separator of consecutive DAGs
101- // CHECK-DAG: [[VAR_24_:%.+]] = arith.index_cast [[VAR_23_]] : i32 to index
102- // CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[CST_0_]]) -> (tensor<4x4x!tt.ptr<f32>>, index) : i32 {
103- // CHECK-DAG: [[VAR_26_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4, 4], strides: {{.}}[[VAR_17_]], [[VAR_18_]]{{.}}, offsets: {{.}}[[VAR_arg10_]], [[CST_0_]]{{.}}, shape: [0, 0], order: [] : <f32> to tensor<4x4x!tt.ptr<f32>>
104- // CHECK-DAG: [[LOAD_VAR_arg9_MEM_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_20_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr<f32>>
105- // CHECK: "tts.store"([[VAR_26_]], [[LOAD_VAR_arg9_MEM_]]) <{static_mask_dims = array<i64>}> : (tensor<4x4x!tt.ptr<f32>>, tensor<4x4xf32>) -> ()
106- // CHECK-DAG: [[VAR_28_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_22_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
107- // CHECK-DAG: [[VAR_29_:%.+]] = arith.addi [[VAR_arg10_]], [[VAR_24_]] : index
108- // CHECK: scf.yield [[VAR_28_]], [[VAR_29_]] : tensor<4x4x!tt.ptr<f32>>, index
108+ // CHECK-DAG: [[VAR_31_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] to [[CST_2_i32]] step [[CST_1_i32]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_25_]]) -> (tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>)
109+ // CHECK: [[VAR_32_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_27_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr<f32>>
110+ // CHECK: tt.store [[VAR_arg10_]], [[VAR_32_]] : !tt.ptr<tensor<4x4xf32>>
111+ // CHECK-DAG: [[VAR_33_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_29_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
112+ // CHECK-DAG: [[VAR_34_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : <tensor<4x4xf32>>
113+ // CHECK: scf.yield [[VAR_33_]], [[VAR_34_]] : tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>
109114// CHECK: }
110115// CHECK: tt.return
111116// CHECK: }
0 commit comments