// RUN: triton-shared-opt --triton-to-structured --split-input-file %s | FileCheck %s

// These tests check that loads/stores whose masks involve a cmp ge against 0
// are handled correctly by the pointer analysis pass.
// Example of the Triton kernel that generates the loads/stores with cmp ge 0.
// The boundary_check fields on the loads/stores, combined with preprocessing
// the kernel through --triton-rewrite-tensor-pointer before running the
// --triton-to-structured pass, produce those cmp ge 0 instructions (see the
// sample invocation sketched after the kernel below).
//
// def kernel(in_ptr0, out_ptr0, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
//     yoffset = tl.program_id(1) * YBLOCK
//     xoffset = tl.program_id(0) * XBLOCK
//     tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[16640, 10],
//                    strides=[1, 16640], block_shape=[XBLOCK, YBLOCK],
//                    order=[1, 0], offsets=[xoffset, yoffset]),
//                    boundary_check=[0, 1])
//     tl.store(tl.make_block_ptr(out_ptr0, shape=[16640, 10],
//              strides=[1, 16640], block_shape=[XBLOCK, YBLOCK],
//              order=[1, 0], offsets=[xoffset, yoffset]),
//              tl.broadcast_to(tmp0, [XBLOCK, YBLOCK]).to(tl.float16),
//              boundary_check=[0, 1])
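//
// A minimal sketch of that preprocessing pipeline (assumed invocation;
// kernel.ttir is a placeholder name for the Triton IR dumped from the kernel,
// and this test's RUN line only exercises --triton-to-structured on the
// already-rewritten IR below):
//
//   triton-shared-opt --triton-rewrite-tensor-pointer --triton-to-structured kernel.ttir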

tt.func public @test_masked_load(%arg0: !tt.ptr<f16>) -> tensor<16x16xf16> {
    %cst = arith.constant dense<0> : tensor<1x16xi64>
    %c16_i32 = arith.constant 16 : i32
    %0 = tt.get_program_id y : i32
    %1 = arith.muli %0, %c16_i32 : i32
    %2 = arith.extsi %1 : i32 to i64
    %3 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<16x16x!tt.ptr<f16>>
    %4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
    %5 = arith.extsi %4 : tensor<16xi32> to tensor<16xi64>
    %6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64>
    %7 = tt.broadcast %6 : tensor<16x1xi64> -> tensor<16x16xi64>
    %8 = tt.addptr %3, %7 : tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi64>
    %9 = tt.splat %2 : i64 -> tensor<16xi64>
    %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
    %11 = arith.extsi %10 : tensor<16xi32> to tensor<16xi64>
    %12 = arith.addi %9, %11 : tensor<16xi64>
    %13 = tt.expand_dims %12 {axis = 0 : i32} : tensor<16xi64> -> tensor<1x16xi64>
    %14 = arith.cmpi sge, %13, %cst : tensor<1x16xi64>
    %15 = tt.broadcast %14 : tensor<1x16xi1> -> tensor<16x16xi1>
    %16 = tt.load %8, %15 evictionPolicy = evict_last : tensor<16x16x!tt.ptr<f16>>
    tt.return %16 : tensor<16x16xf16>
}

// CHECK: tt.func public @test_masked_load([[arg0_:%.+]]: !tt.ptr<f16>) -> tensor<16x16xf16> {
// CHECK: [[VAR_0_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [16, 16], strides: [1, 0], offsets: [0, 0], shape: [0, 0], order: [] : <f16> to tensor<16x16x!tt.ptr<f16>>
// CHECK: [[VAR_1_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array<i32: 1, 0, 0>, static_mask_dims = array<i64: 16, 16>}> : (tensor<16x16x!tt.ptr<f16>>) -> tensor<16x16xf16>
// CHECK: }

// -----

tt.func public @test_masked_store(%arg0: !tt.ptr<f16>) {
    %cst = arith.constant dense<0> : tensor<16x1xi64>
    %cst_0 = arith.constant dense<1.500000e+01> : tensor<16x16xf16>
    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<16x16x!tt.ptr<f16>>
    %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
    %2 = arith.extsi %1 : tensor<16xi32> to tensor<16xi64>
    %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64>
    %4 = tt.broadcast %3 : tensor<16x1xi64> -> tensor<16x16xi64>
    %5 = tt.addptr %0, %4 : tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi64>
    %6 = arith.cmpi sge, %3, %cst : tensor<16x1xi64>
    %7 = tt.broadcast %6 : tensor<16x1xi1> -> tensor<16x16xi1>
    tt.store %5, %cst_0, %7 : tensor<16x16x!tt.ptr<f16>>
    tt.return
}

// CHECK: tt.func public @test_masked_store([[arg0_:%.+]]: !tt.ptr<f16>) {
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1.500000e+01> : tensor<16x16xf16>
// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [16, 16], strides: [1, 0], offsets: [0, 0], shape: [0, 0], order: [] : <f16> to tensor<16x16x!tt.ptr<f16>>
// CHECK: "tts.store"([[VAR_0_]], [[VAR_cst_]]) <{static_mask_dims = array<i64: 16, 16>}> : (tensor<16x16x!tt.ptr<f16>>, tensor<16x16xf16>) -> ()
// CHECK: }