11// RUN: triton-opt --split-input-file %s -triton-licm | FileCheck %s
22
3- tt.func @hoist_load_without_mask (%arg0: tensor <1024 x!tt.ptr <f32 >>, %arg1: tensor <1024 xi32 >, %arg2: tensor <1024 xi32 >, %arg3: i32 , %arg4 : i32 , %arg5: tensor <1024 x!tt.ptr <f32 >>) {
3+ tt.func @hoist_load_without_mask1 (%arg0: tensor <1024 x!tt.ptr <f32 >>, %arg1: tensor <1024 xi32 >, %arg2: tensor <1024 xi32 >, %arg3: i32 , %arg4 : i32 , %arg5: tensor <1024 x!tt.ptr <f32 >>) {
44 %cst = arith.constant dense <0.000000e+00 > : tensor <1024 xf32 >
55 %c1_i32 = arith.constant 1 : i32
66 // Check if the load is hoisted
7- // CHECK-LABEL: hoist_load_without_mask
7+ // CHECK-LABEL: hoist_load_without_mask1
88 // CHECK: %[[TRIP_COUNT_CMP:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
99 // CHECK: %[[SPLAT:.*]] = tt.splat %[[TRIP_COUNT_CMP]]
1010 // CHECK: %[[LOAD:.*]] = tt.load %[[_:.*]], %[[SPLAT]]
@@ -23,6 +23,29 @@ tt.func @hoist_load_without_mask(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor
2323
2424// -----
2525
26+ tt.func @hoist_load_without_mask2 (%arg0: !tt.ptr <tensor <1024 xf32 >>, %arg3: i32 , %arg4 : i32 , %arg5: !tt.ptr <tensor <1024 xf32 >>) {
27+ %cst = arith.constant dense <0.000000e+00 > : tensor <1024 xf32 >
28+ %c1_i32 = arith.constant 1 : i32
29+ // Check if the load is hoisted
30+ // CHECK-LABEL: hoist_load_without_mask2
31+ // CHECK: %[[TRIP_COUNT_CMP:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
32+ // CHECK: %[[SPLAT:.*]] = tt.splat %[[TRIP_COUNT_CMP]]
33+ // CHECK: %[[LOAD:.*]] = tt.load %[[_:.*]], %[[SPLAT]]
34+ // CHECK: arith.addf %[[LOAD]], %[[LOAD]]
35+ // CHECK: scf.for
36+ // CHECK-NOT: tt.load
37+ %1 = scf.for %arg7 = %arg3 to %arg4 step %c1_i32 iter_args (%arg6 = %cst ) -> (tensor <1024 xf32 >) : i32 {
38+ %2 = tt.load %arg0 : !tt.ptr <tensor <1024 xf32 >>
39+ %3 = arith.addf %2 , %2 : tensor <1024 xf32 >
40+ %4 = arith.addf %arg6 , %3 : tensor <1024 xf32 >
41+ scf.yield %4 : tensor <1024 xf32 >
42+ }
43+ tt.store %arg5 , %1 : !tt.ptr <tensor <1024 xf32 >>
44+ tt.return
45+ }
46+
47+ // -----
48+
2649tt.func @hoist_two_loads_without_mask (%arg0: tensor <1024 x!tt.ptr <f32 >>, %arg1: tensor <1024 xi32 >, %arg2: tensor <1024 xi32 >, %arg3: i32 , %arg4 : i32 , %arg5: tensor <1024 x!tt.ptr <f32 >>, %arg6: tensor <1024 x!tt.ptr <f32 >>) {
2750 %cst = arith.constant dense <0.000000e+00 > : tensor <1024 xf32 >
2851 %c1_i32 = arith.constant 1 : i32
0 commit comments