-// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s
 
 gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
-  // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
+  // e.g. for mem_desc<32x32xf32> with the default row-major layout,
   // its memory layout tuple is (blocked shape = [1,1,32,32], strides = [1024,1024,32,1])
   //CHECK-LABEL: load_store_matrix_1
   gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+
+    //CHECK: %[[TID:.*]] = gpu.thread_id x
+    //CHECK: %[[C1:.*]] = arith.constant 1 : index
+    //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+    //CHECK: %[[C4:.*]] = arith.constant 4 : i64
+    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i64
     //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+
     %tid_x = gpu.thread_id x
     %c0 = arith.constant 0 : index
     %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+
+    //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
+
+    xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+
     gpu.return %1: f32
   }
 
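Note on the checks above: the lowering does not emit a raw load directly; it first linearizes the 2-D index [%c0, %tid_x] against the row-major strides of the 32x32 descriptor and then scales the element offset to bytes (the muli by 4 : i64). A minimal Python sketch of that address math, with illustrative names (the pass itself emits the arith ops checked above):

```python
# Sketch of the offset math for mem_desc<32x32xf32> (row-major, no blocking).
def linear_byte_offset(row: int, col: int, elem_bytes: int = 4) -> int:
    strides = (32, 1)                     # element strides of a 32x32 row-major matrix
    elems = row * strides[0] + col * strides[1]
    return elems * elem_bytes             # the muli by 4 : i64 checked above

assert linear_byte_offset(0, 5) == 20     # thread 5 reads the f32 at byte offset 20
```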
   // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
   // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
   //CHECK-LABEL: load_store_matrix_2
   gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK: %[[c13:.*]] = arith.constant 13 : index
+    //CHECK: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
+    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
+    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+
+    //CHECK: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+    //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
+
     %tid_x = gpu.thread_id x
     %c13 = arith.constant 13 : index
     %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
+
+    //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
+
+    xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
     gpu.return %1: f16
   }
 
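The divsi/remsi pairs checked above split each index into a block coordinate and an in-block coordinate; the muli/addi chain is then a dot product with the blocked strides [256, 512, 1, 16]. A sketch of the same computation (the helper name blocked_offset is illustrative, not part of the pass):

```python
# Sketch of the blocked linearization checked above for
# mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>.
def blocked_offset(r: int, c: int) -> int:
    rb, ri = divmod(r, 16)                # arith.divsi / arith.remsi on the row index
    cb, ci = divmod(c, 16)                # ... and on the column index
    strides = (256, 512, 1, 16)           # [row-block, col-block, in-row, in-col]
    return rb * strides[0] + cb * strides[1] + ri * strides[2] + ci * strides[3]

assert blocked_offset(13, 0) == 13        # [%c13, tid 0] stays inside block (0, 0)
```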
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_3
   gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
+
+    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK: %[[c19:.*]] = arith.constant 19 : index
     %tid_x = gpu.thread_id x
     %c19 = arith.constant 19 : index
+
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
+    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
+    //CHECK: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
+    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
+    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
+    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
+    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
+    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
+    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+    //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
+
+    //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
+    xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
+
+    //CHECK: gpu.return %[[loaded]] : f16
     gpu.return %1: f16
   }
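The strides [1024, 256, 16, 1] quoted in the comment for this layout follow mechanically from the block decomposition. A short sketch of the derivation (plain Python, illustrative only):

```python
# Sketch: deriving the blocked strides [1024, 256, 16, 1] for
# mem_desc<32x64xf16, @block=[16, 16]> with the default row-major order.
rows, cols, bh, bw = 32, 64, 16, 16
block_elems = bh * bw                     # 256 elements per 16x16 block
grid_cols = cols // bw                    # 4 blocks per block row
strides = (grid_cols * block_elems,       # next block row: 4 * 256 = 1024
           block_elems,                   # next block column: 256
           bw,                            # next row inside a block: 16
           1)                             # next column inside a block: 1
assert strides == (1024, 256, 16, 1)
```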
-
-  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
+
+  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
   // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
   //CHECK-LABEL: load_store_matrix_4
   gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+
+    //CHECK: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+
+    //CHECK: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+    //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+
     %tid_x = gpu.thread_id x
     %c16 = arith.constant 16 : index
     %1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
+
+    //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3>
+    xegpu.store_matrix %1, %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+
     gpu.return %1: vector<8xf16>
   }
 
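Why a col-direction vector<8xf16> access can lower to a single llvm.load/llvm.store here: with @strides=[1, 32] the row dimension has unit stride, so 8 consecutive rows of one column are adjacent f16 elements. A sketch reusing the blocked-offset math from load_store_matrix_2 (the lane choice is illustrative, and this reading of vec_direction<col> is an assumption):

```python
# Sketch: contiguity of a col-direction vector access under @strides=[1, 32].
def blocked_offset(r: int, c: int) -> int:   # same math as in load_store_matrix_2
    rb, ri = divmod(r, 16)
    cb, ci = divmod(c, 16)
    return rb * 256 + cb * 512 + ri * 1 + ci * 16

lanes = [blocked_offset(16 + i, 0) for i in range(8)]
assert lanes == [256 + i for i in range(8)]  # contiguous lanes -> one vector access
```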
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_5
   gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-    //CHECK: xevm.blockload {{.*}} : (!llvm.ptr<3>) -> vector<8xi16>
+
+    //CHECK: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK: %[[c48:.*]] = arith.constant 48 : index
+
     %c16 = arith.constant 16 : index
     %c48 = arith.constant 48 : index
+
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
+    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
+    //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+    //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+    //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
+    //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+    //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
+    //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
+    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
+    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
+    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
+    //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i64
+    //CHECK: %[[c2:.*]] = arith.constant 2 : i64
+    //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i64
+    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i64
+    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i64 to !llvm.ptr<3>
+    //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
+    //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
+
     %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
+
+    //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
+    //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
+
+    xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
+
     gpu.return %1: vector<8xf16>
   }
 
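The vector.bitcast pair around xevm.blockload/blockstore exists because subgroup block I/O here is typed over integer elements, so each f16 lane travels as its raw 16-bit pattern. A sketch of that reinterpretation (plain Python, illustrative):

```python
# Sketch: the f16 <-> i16 reinterpretation around block I/O. A bitcast keeps
# the bit pattern and only changes the element type; no value conversion.
import struct

def f16_bits(x: float) -> int:
    return struct.unpack("<H", struct.pack("<e", x))[0]

assert f16_bits(1.0) == 0x3C00            # same bits, viewed as i16
```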