|
1 | 1 | // RUN: mlir-opt --xegpu-blocking -split-input-file %s | FileCheck %s |
2 | 2 |
|
3 | | - |
4 | 3 | #a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]> |
5 | 4 | #b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]> |
6 | 5 | #c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]> |
7 | | - |
8 | | -#l1 = #xegpu.layout<inst_data = [8, 16]> |
9 | | -#l2 = #xegpu.layout<inst_data = [16, 16]> |
10 | | - |
11 | 6 | gpu.module @test_kernel { |
12 | 7 | gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { |
13 | 8 | %c0 = arith.constant 0 : index |
@@ -44,9 +39,13 @@ gpu.module @test_kernel { |
44 | 39 | xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c> |
45 | 40 | gpu.return |
46 | 41 | } |
| 42 | +} |
47 | 43 |
|
48 | | - //----- |
49 | | - gpu.func @test_gemm_simple(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { |
| 44 | +// ----- |
| 45 | +#l1 = #xegpu.layout<inst_data = [8, 16]> |
| 46 | +#l2 = #xegpu.layout<inst_data = [16, 16]> |
| 47 | +gpu.module @test_kernel { |
| 48 | + gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { |
50 | 49 | %c0 = arith.constant 0 : index |
51 | 50 | %c16 = arith.constant 16 : index |
52 | 51 | %c32 = arith.constant 32 : index |
@@ -81,10 +80,14 @@ gpu.module @test_kernel { |
81 | 80 | xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #l1> |
82 | 81 | gpu.return |
83 | 82 | } |
| 83 | +} |
84 | 84 |
|
85 | | - //----- |
86 | | - |
87 | | - gpu.func @test_gemm_a_preop(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { |
| 85 | +// ----- |
| 86 | +#a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]> |
| 87 | +#b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]> |
| 88 | +#c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]> |
| 89 | +gpu.module @test_kernel { |
| 90 | + gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { |
88 | 91 | %c0 = arith.constant 0 : index |
89 | 92 | %c16 = arith.constant 16 : index |
90 | 93 | %c32 = arith.constant 32 : index |
@@ -120,4 +123,83 @@ gpu.module @test_kernel { |
120 | 123 | //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>> |
121 | 124 | xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c> |
122 | 125 | gpu.return |
123 | | - }} |
| 126 | + } |
| 127 | +} |
| 128 | + |
| 129 | +// ----- |
| 130 | +#l = #xegpu.layout<inst_data = [8, 16]> |
gpu.module @test_kernel {
  // Blocking of a 2D elementwise op: a vector<16x32xf16> arith.addf with an
  // inst_data = [8, 16] layout must be split into 2x2 = 4 vector<8x16xf16>
  // pieces, and each !xegpu.tensor_desc<16x32xf16> into 4 8x16 tiles
  // (hence 2 loads x 4 = 8 load_nd, 4 addf, 4 store_nd, 3 descs x 4 = 12
  // update_nd_offset in the CHECK-COUNT lines below).
  gpu.func @test_elementwise(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
    %c0 = arith.constant 0 : index
    %c32 = arith.constant 32 : index
    %c1024 = arith.constant 1024 : index
    %block_id_x = gpu.block_id x
    // NOTE(review): %block_id_y is unused, and %m advances 32 rows per block
    // while the tile is only 16 rows tall — presumably fine for exercising the
    // blocking pass, but confirm the overlap is intentional.
    %block_id_y = gpu.block_id y
    %m = arith.muli %block_id_x, %c32 : index

    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
    %b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>

    // Walk the row in 32-column steps; all three descriptors are loop-carried.
    %out:3 = scf.for %k = %c0 to %c1024 step %c32
      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
        -> (!xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>) {
      //CHECK-COUNT-8: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #l> -> vector<16x32xf16>
      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16, #l> -> vector<16x32xf16>

      //CHECK-COUNT-4: arith.addf {{.*}} : vector<8x16xf16>
      %c = arith.addf %a, %b {layout_result_0 = #l} : vector<16x32xf16>

      //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
      xegpu.store_nd %c, %arg2: vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #l>

      //CHECK-COUNT-12: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16>
      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
      %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
      scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
          : !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>
    }
    gpu.return
  }
}
| 167 | + |
| 168 | +// ----- |
| 169 | +#l = #xegpu.layout<inst_data = [8]> |
gpu.module @test_kernel {
  // Blocking of a 1D elementwise op: a vector<32xf16> arith.addf with an
  // inst_data = [8] layout must be split into 32/8 = 4 vector<8xf16> pieces,
  // and each !xegpu.tensor_desc<32xf16> into 4 8-element tiles (hence
  // 2 loads x 4 = 8 load_nd, 4 addf, 4 store_nd, 3 descs x 4 = 12
  // update_nd_offset in the CHECK-COUNT lines below).
  gpu.func @test_elementwise(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
    %c0 = arith.constant 0 : index
    %c32 = arith.constant 32 : index
    %c1024 = arith.constant 1024 : index
    %block_id_x = gpu.block_id x
    // NOTE(review): %block_id_y is unused; kept to mirror the 2D variant above.
    %block_id_y = gpu.block_id y
    %m = arith.muli %block_id_x, %c32 : index

    // 1D descriptors created from a 2D memref at offset [%m, %c0].
    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
    %b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>

    // Walk in 32-element steps; all three descriptors are loop-carried.
    %out:3 = scf.for %k = %c0 to %c1024 step %c32
      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
        -> (!xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>) {
      //CHECK-COUNT-8: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8xf16> -> vector<8xf16>
      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<32xf16, #l> -> vector<32xf16>
      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32xf16, #l> -> vector<32xf16>

      //CHECK-COUNT-4: arith.addf {{.*}} : vector<8xf16>
      %c = arith.addf %a, %b {layout_result_0 = #l} : vector<32xf16>

      //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8xf16>, !xegpu.tensor_desc<8xf16>
      xegpu.store_nd %c, %arg2: vector<32xf16>, !xegpu.tensor_desc<32xf16, #l>

      //CHECK-COUNT-12: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8xf16>
      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c32] : !xegpu.tensor_desc<32xf16, #l>
      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32] : !xegpu.tensor_desc<32xf16, #l>
      %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c32] : !xegpu.tensor_desc<32xf16, #l>
      scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
          : !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>
    }
    gpu.return
  }
}
0 commit comments