// RUN: %python_executable %imex_runner --requires=l0-runtime,spirv-backend -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
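// This test lowers the XeGPU ops through the xegpu-to-llvm.pp pipeline, runs the
// module on the Level Zero runtime, and checks the printed output with FileCheck.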

module @gemm attributes {gpu.container_module} {
  gpu.module @kernel {
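    // Each workgroup accumulates one 8x16 tile of C (C += A * B) with DPAS,
    // walking the K dimension (256) in steps of 16.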
    gpu.func @load_store_2d_dpas(%a: memref<256x256xf16>, %b: memref<256x256xf16>, %c: memref<256x256xf32>) kernel {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      %c8 = arith.constant 8 : index
      %c16 = arith.constant 16 : index
      %c32 = arith.constant 32 : index
      %c256 = arith.constant 256 : index
      %block_x = gpu.block_id x
      %block_y = gpu.block_id y
      %x_block_offset = arith.muli %block_x, %c8 : index
      %y_block_offset = arith.muli %block_y, %c16 : index

      %c_tdesc = xegpu.create_nd_tdesc %c[%x_block_offset, %y_block_offset] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
      %c_init_value = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>> -> vector<8xf32>

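      // Each iteration loads an 8x16 tile of A and a 16x16 tile of B and feeds
      // them to DPAS; the vector operands are the per-work-item fragments of
      // these tiles, distributed over the 16 work items of the subgroup.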
      %r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8xf32>) {
        // TODO: There is an issue with update_nd_offset; as a workaround, create a fresh descriptor with create_nd_tdesc on every iteration.
        %a_tdesc_new = xegpu.create_nd_tdesc %a[%x_block_offset, %k] : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_space = global>>
        %b_tdesc_new = xegpu.create_nd_tdesc %b[%k, %y_block_offset] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global>>
        %a_val = xegpu.load_nd %a_tdesc_new : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_space = global>> -> vector<8xf16>
        %b_val = xegpu.load_nd %b_tdesc_new : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global>> -> vector<16xf16>
        %dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
        scf.yield %dpas : vector<8xf32>
      }
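      // Write the accumulated tile back to C with a write-back L1 hint and an
      // uncached L2 hint.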
      xegpu.store_nd %r, %c_tdesc <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
      gpu.return
    }
  }

  func.func @test(%a : memref<256x256xf16>, %b : memref<256x256xf16>, %c : memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
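    // Allocate host-shared buffers (accessible from both host and device) and
    // copy the inputs into them.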
    %memref_a = gpu.alloc host_shared () : memref<256x256xf16>
    memref.copy %a, %memref_a : memref<256x256xf16> to memref<256x256xf16>
    %memref_b = gpu.alloc host_shared () : memref<256x256xf16>
    memref.copy %b, %memref_b : memref<256x256xf16> to memref<256x256xf16>
    %memref_c = gpu.alloc host_shared () : memref<256x256xf32>
    memref.copy %c, %memref_c : memref<256x256xf32> to memref<256x256xf32>

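    // 256/8 = 32 workgroups cover the rows of C and 256/16 = 16 cover the
    // columns; each workgroup launches 16 threads (one subgroup).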
    gpu.launch_func @kernel::@load_store_2d_dpas blocks in (%c32, %c16, %c1) threads in (%c16, %c1, %c1) args(%memref_a : memref<256x256xf16>, %memref_b : memref<256x256xf16>, %memref_c : memref<256x256xf32>)
    return %memref_c : memref<256x256xf32>
  }

  // Compute the CPU reference result (takes minutes).
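  // The loop nest mirrors the GPU tiling (an outer loop over 16-wide K tiles and
  // an inner loop within each tile) and accumulates in f32, like the DPAS accumulator.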
  func.func @cpu_reference(%A : memref<256x256xf16>, %B : memref<256x256xf16>, %C : memref<256x256xf32>) {
    %c256 = arith.constant 256 : index
    %c16 = arith.constant 16 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    scf.for %i = %c0 to %c256 step %c1 {
      scf.for %j = %c0 to %c256 step %c1 {
        %c_curr = memref.load %C[%i, %j] : memref<256x256xf32>
        %c_val = scf.for %k_tile = %c0 to %c256 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
          %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
            %k_dpas = arith.addi %k_tile, %k : index
            %a_val = memref.load %A[%i, %k_dpas] : memref<256x256xf16>
            %b_val = memref.load %B[%k_dpas, %j] : memref<256x256xf16>
            %a_cast = arith.extf %a_val : f16 to f32
            %b_cast = arith.extf %b_val : f16 to f32
            %t = arith.mulf %a_cast, %b_cast : f32
            %c_sum = arith.addf %t, %c_dpas_partial : f32
            scf.yield %c_sum : f32
          }
          scf.yield %c_val_dpas : f32
        }
        // %c_val_f16 = arith.truncf %c_val : f32 to f16
        // %c_val_ = arith.extf %c_val_f16 : f16 to f32
        memref.store %c_val, %C[%i, %j] : memref<256x256xf32>
      }
    }
    return
  }

  func.func @main() attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c1_f16 = arith.constant 1.0 : f16
    %c2_f16 = arith.constant 2.0 : f16
    %c256 = arith.constant 256 : index
    %cf_0 = arith.constant 0.0 : f16
    %cf_1 = arith.constant 1.0 : f16
    %A = memref.alloc() : memref<256x256xf16>
    %B = memref.alloc() : memref<256x256xf16>
    %C = memref.alloc() : memref<256x256xf32>
    %C_ref = memref.alloc() : memref<256x256xf32>
    %c_gen_int = arith.constant 0 : i1
    %cf_lower = arith.constant -0.5 : f32
    %cf_upper = arith.constant 0.5 : f32
    // Use one of the two options below to initialize the A matrix.
    // Option 1: initialize matrix A with A[i, j] = j.
    // scf.for %i = %c0 to %c256 step %c1 {
    //   scf.for %j = %c0 to %c256 step %c1 {
    //     %t = index.castu %j : index to i16
    //     %val = arith.uitofp %t : i16 to f16
    //     memref.store %val, %A[%i, %j] : memref<256x256xf16>
    //     // memref.store %c1_f16, %A[%i, %j] : memref<256x256xf16>
    //     // memref.store %c2_f16, %B[%i, %j] : memref<256x256xf16>
    //   }
    // }
    // Option 2: convert the memref to 1D and fill it with random values in (-0.5, 0.5).
    %A_random = memref.cast %A : memref<256x256xf16> to memref<*xf16>
    call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()

    // Use one of the two options below to initialize the B matrix.
    // Option 1: make matrix B an identity matrix.
    // scf.for %i = %c0 to %c256 step %c1 {
    //   scf.for %j = %c0 to %c256 step %c1 {
    //     %i_i32 = index.castu %i : index to i32
    //     %j_i32 = index.castu %j : index to i32
    //     %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
    //     scf.if %i_j_same {
    //       memref.store %cf_1, %B[%i, %j] : memref<256x256xf16>
    //     } else {
    //       memref.store %cf_0, %B[%i, %j] : memref<256x256xf16>
    //     }
    //   }
    // }
    // Option 2: convert the memref to 1D and fill it with random values in (-0.5, 0.5).
    %B_random = memref.cast %B : memref<256x256xf16> to memref<*xf16>
    call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()

    // Initialize matrices C and C_ref: C[i, j] = 0.
    %c0_f32 = arith.constant 0.0 : f32
    scf.for %i = %c0 to %c256 step %c1 {
      scf.for %j = %c0 to %c256 step %c1 {
        memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32>
        memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
      }
    }
    // Print input for debugging:
    // %A_row_0 = memref.subview %A[1, 0][1, 256][1, 1] : memref<256x256xf16> to memref<1x256xf16, strided<[256, 1], offset: 256>>
    // %A_row_0_cast = memref.cast %A_row_0 : memref<1x256xf16, strided<[256, 1], offset: 256>> to memref<*xf16>
    // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> ()

    // Run the GPU kernel.
    %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>

    call @cpu_reference(%A, %B, %C_ref) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> ()

    // %cast = memref.cast %A : memref<256x256xf16> to memref<*xf16>
    // call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
    %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
    %cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32>
    // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> ()
    // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> ()

    %C_row_0 = memref.subview %C_ref[0, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32, strided<[256, 1], offset: 0>>
    %C_row_0_cast = memref.cast %C_row_0 : memref<1x256xf32, strided<[256, 1], offset: 0>> to memref<*xf32>
    // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> ()

    %C_row_0_gpu = memref.subview %2[0, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32, strided<[256, 1], offset: 0>>
    %C_row_0_cast_gpu = memref.cast %C_row_0_gpu : memref<1x256xf32, strided<[256, 1], offset: 0>> to memref<*xf32>
    // call @printMemrefF32(%C_row_0_cast_gpu) : (memref<*xf32>) -> ()

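    // printAllcloseF32 prints "[ALLCLOSE: TRUE]" when the GPU result matches the
    // CPU reference within tolerance; FileCheck verifies that output below.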
    // CHECK: [ALLCLOSE: TRUE]
    call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
    memref.dealloc %A : memref<256x256xf16>
    memref.dealloc %B : memref<256x256xf16>
    memref.dealloc %C : memref<256x256xf32>
    memref.dealloc %C_ref : memref<256x256xf32>
    return
  }
  func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
}