|
1 | 1 | // RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
|
2 | 2 |
|
3 | 3 | #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
|
4 |
| -#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}> |
5 |
| -#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> |
6 | 4 | module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
|
7 |
| - tt.func public @matmul_no_scf_with_advance_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64) { |
| 5 | + tt.func public @matmul_no_scf_with_advance_kernel(%base: !tt.ptr<f16>, %width: i64, %height: i64, %rowStride: i64) { |
8 | 6 | %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #dpas>
|
9 |
| - %c32_i32 = arith.constant 32 : i32 |
10 |
| - %c-64_i32 = arith.constant -64 : i32 |
11 |
| - %c-32_i32 = arith.constant -32 : i32 |
12 |
| - %c64_i32 = arith.constant 64 : i32 |
13 | 7 | %c0_i32 = arith.constant 0 : i32
|
14 | 8 | %c1_i64 = arith.constant 1 : i64
|
15 |
| - %13 = tt.make_tensor_ptr %arg2, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #dpas>> |
| 9 | + %0 = tt.make_tensor_ptr %base, [%width, %height], [%rowStride, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #dpas>> |
16 | 10 | // CHECK: %[[WARP_ID:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
|
17 | 11 | // CHECK: %[[offsetBaseY:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
|
18 | 12 | // CHECK: %[[offsetBaseX:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
|
@@ -42,7 +36,58 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
|
42 | 36 | // CHECK: llvm.mlir.undef : vector<8xf16>
|
43 | 37 | // CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
|
44 | 38 | // CHECK: triton_gen.2Dblockstore {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
|
45 |
| - tt.store %13, %cst {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #dpas>> |
| 39 | + tt.store %0, %cst {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #dpas>> |
| 40 | + tt.return |
| 41 | + } |
| 42 | +} |
| 43 | + |
| 44 | +// ----- |
| 45 | + |
| 46 | +#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> |
| 47 | +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} { |
| 48 | + tt.func public @no_boundary_check(%base: !tt.ptr<f16>, %width: i64, %height: i64, %rowStride: i64) { |
| 49 | + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #dpas> |
| 50 | + %c0_i32 = arith.constant 0 : i32 |
| 51 | + %c1_i64 = arith.constant 1 : i64 |
| 52 | + %0 = tt.make_tensor_ptr %base, [%width, %height], [%rowStride, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #dpas>> |
| 53 | + |
| 54 | + // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32 |
| 55 | + // CHECK: %[[WARP_ID:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 |
| 56 | + |
| 57 | + // CHECK: %[[offsetBaseY:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> |
| 58 | + // CHECK: %[[offsetBaseX:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> |
| 59 | + // CHECK: %[[baseHeight:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> |
| 60 | + // CHECK: %[[baseWidth:.*]] = llvm.extractvalue {{.*}}[3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> |
| 61 | + // CHECK: %[[rowStride:.*]] = llvm.extractvalue {{.*}}[4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> |
| 62 | + // CHECK: %[[colStride:.*]] = llvm.extractvalue {{.*}}[5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> |
| 63 | + // CHECK: %[[base:.*]] = llvm.extractvalue {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> |
| 64 | + |
| 65 | + // CHECK: %[[rowStride_i32:.*]] = llvm.trunc %[[rowStride]] : i64 to i32 |
| 66 | + // CHECK: %[[PITCH:.*]] = llvm.mul %[[rowStride_i32]], %[[C2]] |
| 67 | + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(f16, f16, {{.*}})> |
| 68 | + |
| 69 | + // COM: Skip the register, lane, warp and block to the offset computation which should be covered by the LL tests. |
| 70 | + // CHECK: %[[OFFSET_X:.*]] = llvm.add %[[offsetBaseX]], {{.*}} : i32 |
| 71 | + // CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[offsetBaseY]], {{.*}} : i32 |
| 72 | + |
| 73 | + // COM: When boundary check is absent: |
| 74 | + // CHECK: %[[baseWidth:.*]] = llvm.mlir.constant(64 : i32) |
| 75 | + // CHECK: %[[base1:.*]] = llvm.getelementptr %[[base]][%[[OFFSET_X]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i16 |
| 76 | + // CHECK: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32 |
| 77 | + // CHECK: %[[baseHeight:.*]] = llvm.mlir.constant(8 : i32) |
| 78 | + // CHECK: %[[OFF:.*]] = llvm.mul %[[OFFSET_Y]], %[[PITCH]] : i32 |
| 79 | + // CHECK: %[[base:.*]] = llvm.getelementptr %[[base1]][%[[OFF]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8 |
| 80 | + // CHECK: %[[OFFSET_Y:.*]] = llvm.mlir.constant(0 : i32) : i32 |
| 81 | + |
| 82 | + // CHECK: llvm.mlir.undef : vector<8xf16> |
| 83 | + // CHECK-COUNT-7: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16> |
| 84 | + // CHECK: %[[VAL0:.*]] = llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16> |
| 85 | + // CHECK: %[[VAL:.*]] = llvm.bitcast %[[VAL0]] : vector<8xf16> to vector<8xi16> |
| 86 | + |
| 87 | + // CHECK: triton_gen.2Dblockstore %[[base]], %[[baseWidth]], %[[baseHeight]], %[[PITCH]], %[[OFFSET_X]], %[[OFFSET_Y]], %[[VAL]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} |
| 88 | + // CHECK-COUNT-3: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} |
| 89 | + |
| 90 | + tt.store %0, %cst {ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #dpas>> |
46 | 91 | tt.return
|
47 | 92 | }
|
48 | 93 | }
|
|
0 commit comments