 // RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
+  // CHECK-LABEL: gpu.func @load_store
+  // CHECK-SAME: %[[ARG0:.*]]: ui64, %[[ARG1:.*]]: ui32) kernel {
   gpu.func @load_store(%src: ui64, %dst: ui32) kernel {
-    // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-    // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C0_I32:.*]] = arith.constant 0
+    // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+    // CHECK: %[[ARG1_IDX:.*]] = index.castu %[[ARG1]] : ui32 to index
+    // CHECK: %[[ARG1_I32:.*]] = arith.index_castui %[[ARG1_IDX]] : index to i32
+    // CHECK: %[[ARG0_IDX:.*]] = index.castu %[[ARG0]] : ui64 to index
+    // CHECK: %[[ARG0_I64:.*]] = arith.index_castui %[[ARG0_IDX]] : index to i64
     %c8 = arith.constant 8 : index
     %c16 = arith.constant 16 : index
     %c1 = arith.constant 1 : index
     %src_tdesc = xegpu.create_nd_tdesc %src, shape:[%c8, %c16], strides:[%c16, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
 
 
-    //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-    //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-    //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
-    //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-    //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
-    //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
-    //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
-    //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
-    //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
-    //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
-    //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
-    //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    // CHECK: %[[VAR4:.*]] = llvm.inttoptr %[[ARG0_I64]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOAD:.*]] = xevm.blockload2d %[[VAR4]], %[[C64_I32]], %[[C8_I32]], %[[C64_I32]],
+    // CHECK-SAME: %[[C0_I32]], %[[C0_I32]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>,
+    // CHECK-SAME: elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32,
+    // CHECK-SAME: tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
+    // CHECK: %[[VAR6:.*]] = vector.bitcast %[[LOAD]] : vector<8xi32> to vector<8xf32>
     %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
       : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-    //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
 
     %tid_x = gpu.thread_id x
     %tid_x_i32 = arith.index_cast %tid_x : index to i32
     %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
-    //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
+    // CHECK: %[[VAR9:.*]] = vector.insert
     %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
-    // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-    // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[VAR10:.*]] = arith.extui %[[ARG1_I32]] : i32 to i64
     %dst_tdesc = xegpu.create_nd_tdesc %dst, shape:[%c8, %c16], strides:[%c16, %c1] : ui32 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
 
-    //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
-    //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
-    //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
-    //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
-    //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
-    //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
-    //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
-    //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
-    //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
-    //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
-    //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
-    //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+    // CHECK: %[[VAR11:.*]] = llvm.inttoptr %[[VAR10]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[STORE:.*]] = vector.bitcast %[[VAR9]] : vector<8xf32> to vector<8xi32>
+    // CHECK: xevm.blockstore2d %[[VAR11]], %[[C64_I32]], %[[C8_I32]], %[[C64_I32]], %[[C0_I32]], %[[C0_I32]], %[[STORE]]
+    // CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>,
+    // CHECK-SAME: elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32}>
     xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
       : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
     gpu.return