Commit 57f9c68

Add support for N-D vector extract in vector.extract_strided_slice (#698)
1 parent 8824668 commit 57f9c68

4 files changed: +292 -11 lines changed

lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp

Lines changed: 68 additions & 11 deletions
@@ -1369,29 +1369,86 @@ struct VectorExtractStridedSlice final
     if (!dstType)
       return failure();
 
-    // fixme : currently only support 1D vectors
-    if (extractOp.getSourceVectorType().getRank() != 1)
-      return failure();
-
-    uint64_t offset = getFirstIntValue(extractOp.getOffsets());
-    uint64_t size = getFirstIntValue(extractOp.getSizes());
-    uint64_t stride = getFirstIntValue(extractOp.getStrides());
+    auto offsets = extractOp.getOffsets().getValue();
+    auto sizes = extractOp.getSizes().getValue();
+    auto strides = extractOp.getStrides().getValue();
 
-    if (stride != 1)
-      return failure();
+    if (strides[0].cast<IntegerAttr>().getInt() != 1)
+      return rewriter.notifyMatchFailure(
+          extractOp, "Strided slice with stride != 1 is not supported.");
 
     Value srcVector = adaptor.getOperands().front();
 
     // Extract vector<1xT> case.
     if (isa<spirv::ScalarType>(dstType)) {
+      uint64_t offset = getFirstIntValue(extractOp.getOffsets());
       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
                                                              srcVector, offset);
       return success();
     }
 
-    SmallVector<int32_t, 2> indices(size);
-    std::iota(indices.begin(), indices.end(), offset);
+    // If k-D offsets are specified for an n-D source vector (n > k), the
+    // granularity of the extraction is greater than 1; the last (n - k)
+    // dimensions form the extraction granularity. Example:
+    //   %0 = vector.extract_strided_slice %src { offsets = [0, 0],
+    //     sizes = [2, 2], strides = [1, 1]} : vector<4x8x8xf32> to vector<2x2x8xf32>
+    // Here the extraction granularity is 8.
+    int64_t extractSliceLen = 1;
+    auto n = extractOp.getSourceVectorType().getRank();
+    auto k = (int64_t)offsets.size();
+    if (n > k) {
+      for (unsigned i = 0; i < n - k; i++) {
+        extractSliceLen *= extractOp.getSourceVectorType().getShape()[i + k];
+      }
+    }
+
+    // Get the total number of extracted slices.
+    int64_t nExtractedSlices = 1;
+    for (auto size : sizes) {
+      nExtractedSlices *= size.cast<IntegerAttr>().getInt();
+    }
 
+    // Compute the strides of the source vector, considering the first k dimensions.
+    SmallVector<int32_t, 4> sourceStrides(k, extractSliceLen);
+    for (int i = k - 2; i >= 0; --i) {
+      sourceStrides[i] = sourceStrides[i + 1] *
+                         extractOp.getSourceVectorType().getShape()[i + 1];
+    }
+    // The final shuffle indices have nExtractedSlices * extractSliceLen elements.
+    SmallVector<int32_t, 4> indices(nExtractedSlices * extractSliceLen);
+    // Compute the strides of the extracted k-D vector.
+    SmallVector<int32_t, 4> extractedStrides(k, 1);
+    for (int i = k - 2; i >= 0; --i) {
+      extractedStrides[i] =
+          extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
+    }
+    // Iterate over all extracted slices from 0 to nExtractedSlices - 1 and
+    // compute the multi-dimensional index and the corresponding linearized
+    // index within the source vector.
+    for (int64_t i = 0; i < nExtractedSlices; ++i) {
+      int64_t index = i;
+      // Compute the corresponding multi-dimensional index.
+      SmallVector<int32_t, 4> multiDimIndex(k, 0);
+      for (int64_t j = 0; j < k; ++j) {
+        multiDimIndex[j] = (index / extractedStrides[j]);
+        index -= multiDimIndex[j] * extractedStrides[j];
+      }
+      // Compute the corresponding linearized index in the source vector,
+      // i.e. shift the multiDimIndex by the offsets.
+      int64_t linearizedIndex = 0;
+      for (int64_t j = 0; j < k; ++j) {
+        linearizedIndex +=
+            (offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) *
+            sourceStrides[j];
+      }
+      // Fill the indices array from linearizedIndex to
+      // linearizedIndex + extractSliceLen.
+      for (int64_t j = 0; j < extractSliceLen; ++j) {
+        indices[i * extractSliceLen + j] = linearizedIndex + j;
+      }
+    }
+    // Perform a shuffle to extract the k-D vector.
     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
         extractOp, dstType, srcVector, srcVector,
         rewriter.getI32ArrayAttr(indices));
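
For readers who want to sanity-check the index arithmetic above, the following is a minimal standalone C++ sketch (not part of the patch) that recomputes the shuffle indices for the example from the code comment, extracting vector<2x2x8xf32> out of vector<4x8x8xf32> at offsets [0, 0]. The helper name computeShuffleIndices and its plain std::vector signature are made up for illustration only.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the lowering's index arithmetic: given the
// source shape, k-D offsets and k-D sizes (all strides assumed to be 1),
// return the flat shuffle indices into the linearized source vector.
static std::vector<int32_t>
computeShuffleIndices(const std::vector<int64_t> &srcShape,
                      const std::vector<int64_t> &offsets,
                      const std::vector<int64_t> &sizes) {
  int64_t n = srcShape.size(), k = offsets.size();
  // Extraction granularity: product of the trailing (n - k) dimensions.
  int64_t sliceLen = 1;
  for (int64_t i = k; i < n; ++i)
    sliceLen *= srcShape[i];
  // Total number of extracted slices: product of the sizes.
  int64_t nSlices = 1;
  for (int64_t s : sizes)
    nSlices *= s;
  // Row-major strides of the source (first k dims) and of the extracted shape.
  std::vector<int64_t> srcStrides(k, sliceLen), extStrides(k, 1);
  for (int64_t i = k - 2; i >= 0; --i) {
    srcStrides[i] = srcStrides[i + 1] * srcShape[i + 1];
    extStrides[i] = extStrides[i + 1] * sizes[i + 1];
  }
  std::vector<int32_t> indices(nSlices * sliceLen);
  for (int64_t i = 0; i < nSlices; ++i) {
    // Delinearize slice number i into a k-D index, shift it by the offsets,
    // and re-linearize it against the source strides.
    int64_t rem = i, start = 0;
    for (int64_t j = 0; j < k; ++j) {
      int64_t d = rem / extStrides[j];
      rem -= d * extStrides[j];
      start += (offsets[j] + d) * srcStrides[j];
    }
    for (int64_t j = 0; j < sliceLen; ++j)
      indices[i * sliceLen + j] = static_cast<int32_t>(start + j);
  }
  return indices;
}

int main() {
  // Example from the code comment above: vector<4x8x8xf32> -> vector<2x2x8xf32>
  // with offsets = [0, 0] and sizes = [2, 2]; the granularity is 8.
  for (int32_t idx : computeShuffleIndices({4, 8, 8}, {0, 0}, {2, 2}))
    std::cout << idx << ' ';
  std::cout << '\n'; // prints 0..15 followed by 64..79
  return 0;
}

Running this prints 0..15 followed by 64..79, i.e. the first two 8-element rows of the first two 8x8 planes of the source, matching what the lowering feeds to spirv.VectorShuffle.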

test/Conversion/XeGPUToSPIRV/xegpu-to-vc.mlir

Lines changed: 32 additions & 0 deletions
@@ -67,4 +67,36 @@ gpu.module @test attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.
     gpu.return
   }
 
+  gpu.func @vector_extract_strided_slice(%src_1d : vector<128xf32>, %src_2d : vector<8x16xf32>, %src_nd : vector<2x32x8xf32>)
+    kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    // CHECK: spirv.VectorShuffle [32 : i32, 33 : i32, 34 : i32, 35 : i32, 36 : i32, 37 : i32, 38 : i32, 39 : i32,
+    // CHECK: 40 : i32, 41 : i32, 42 : i32, 43 : i32, 44 : i32, 45 : i32, 46 : i32, 47 : i32]
+    // CHECK: %[[vec_1d:.*]], %[[vec_1d]] : vector<128xf32>, vector<128xf32> -> vector<16xf32>
+    %0 = vector.extract_strided_slice %src_1d {sizes = [16], strides = [1], offsets = [32]}
+      : vector<128xf32> to vector<16xf32>
+
+    // CHECK: spirv.VectorShuffle [8 : i32, 9 : i32, 10 : i32, 11 : i32, 12 : i32, 13 : i32, 14 : i32, 15 : i32, 24 : i32,
+    // CHECK: 25 : i32, 26 : i32, 27 : i32, 28 : i32, 29 : i32, 30 : i32, 31 : i32, 40 : i32, 41 : i32, 42 : i32,
+    // CHECK: 43 : i32, 44 : i32, 45 : i32, 46 : i32, 47 : i32, 56 : i32, 57 : i32, 58 : i32, 59 : i32, 60 : i32,
+    // CHECK: 61 : i32, 62 : i32, 63 : i32, 72 : i32, 73 : i32, 74 : i32, 75 : i32, 76 : i32, 77 : i32, 78 : i32,
+    // CHECK: 79 : i32, 88 : i32, 89 : i32, 90 : i32, 91 : i32, 92 : i32, 93 : i32, 94 : i32, 95 : i32, 104 : i32,
+    // CHECK: 105 : i32, 106 : i32, 107 : i32, 108 : i32, 109 : i32, 110 : i32, 111 : i32, 120 : i32, 121 : i32,
+    // CHECK: 122 : i32, 123 : i32, 124 : i32, 125 : i32, 126 : i32, 127 : i32]
+    // CHECK: %[[vec_2d:.*]], %[[vec_2d]] : vector<128xf32>, vector<128xf32> -> vector<64xf32>
+    %1 = vector.extract_strided_slice %src_2d {sizes = [8, 8], strides = [1, 1], offsets = [0, 8]}
+      : vector<8x16xf32> to vector<8x8xf32>
+
+    // CHECK: spirv.VectorShuffle [192 : i32, 193 : i32, 194 : i32, 195 : i32, 196 : i32, 197 : i32, 198 : i32,
+    // CHECK: 199 : i32, 200 : i32, 201 : i32, 202 : i32, 203 : i32, 204 : i32, 205 : i32, 206 : i32, 207 : i32,
+    // CHECK: 208 : i32, 209 : i32, 210 : i32, 211 : i32, 212 : i32, 213 : i32, 214 : i32, 215 : i32, 216 : i32,
+    // CHECK: 217 : i32, 218 : i32, 219 : i32, 220 : i32, 221 : i32, 222 : i32, 223 : i32, 224 : i32, 225 : i32,
+    // CHECK: 226 : i32, 227 : i32, 228 : i32, 229 : i32, 230 : i32, 231 : i32, 232 : i32, 233 : i32, 234 : i32,
+    // CHECK: 235 : i32, 236 : i32, 237 : i32, 238 : i32, 239 : i32, 240 : i32, 241 : i32, 242 : i32, 243 : i32,
+    // CHECK: 244 : i32, 245 : i32, 246 : i32, 247 : i32, 248 : i32, 249 : i32, 250 : i32, 251 : i32, 252 : i32,
+    // CHECK: 253 : i32, 254 : i32, 255 : i32] %[[vec_nd:.*]], %[[vec_nd]] : vector<512xf32>, vector<512xf32> -> vector<64xf32>
+    %2 = vector.extract_strided_slice %src_nd { offsets = [0, 24], strides = [1, 1], sizes = [1, 8] } : vector<2x32x8xf32> to vector<1x8x8xf32>
+    gpu.return
+
+  }
+
 }
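
The CHECK lists above can be verified by hand from the linearization in the conversion. In the 2-D case the number of offsets equals the source rank, so the slice length is 1; with source shape 8x16, offsets [0, 8] and sizes [8, 8], element (r, c) of the result maps to flat index (0 + r) * 16 + (8 + c) = 16r + 8 + c for r, c in [0, 8), which is exactly the checked sequence 8-15, 24-31, ..., 120-127. In the N-D case the slice length is 8, so slice i starts at (24 + i) * 8 = 192 + 8i and covers 8 consecutive elements, giving indices 192-255.
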
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
module @gemm attributes {gpu.container_module} {
  func.func @test(%A: memref<8x16xf16>, %B: memref<16x16xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %memref = gpu.alloc host_shared () : memref<8x16xf16>
    %memref_1 = gpu.alloc host_shared () : memref<16x16xf16>
    memref.copy %A, %memref : memref<8x16xf16> to memref<8x16xf16>
    memref.copy %B, %memref_1 : memref<16x16xf16> to memref<16x16xf16>
    %memref_2 = gpu.alloc host_shared () : memref<8x16xf32>
    gpu.launch_func @module0::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_1 : memref<16x16xf16>, %memref_2 : memref<8x16xf32>)
    gpu.dealloc %memref : memref<8x16xf16>
    gpu.dealloc %memref_1 : memref<16x16xf16>
    return %memref_2 : memref<8x16xf32>
  }

  gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
    gpu.func @test_kernel(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %c0 = arith.constant 0 : index
      %c16 = arith.constant 16 : index
      // load A tile
      %a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] { mode = vc } : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
      %val0 = xegpu.load_nd %a_tile0 { mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
      // load B tile
      %b_tile0 = xegpu.create_nd_tdesc %B [%c0, %c0] { mode = vc } : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
      %val2 = xegpu.load_nd %b_tile0 { mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
      // do DPAS
      %val4 = xegpu.dpas %val0, %val2 : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
      // extract the second 8x8 (columns 8..15) of the DPAS result
      %val5 = vector.extract_strided_slice %val4 {sizes = [8, 8], strides = [1, 1], offsets = [0, 8]} : vector<8x16xf32> to vector<8x8xf32>
      %cst_8x8_flat = arith.constant dense<1.0> : vector<64xf32>
      %cst_8x8 = vector.shape_cast %cst_8x8_flat : vector<64xf32> to vector<8x8xf32>
      // interleave the rows of %val5 and %cst_8x8 into a 16x8 vector, then view it as 8x16:
      // the extracted tile ends up in the left half and %cst_8x8 in the right half
      %val6 = vector.shuffle %val5, %cst_8x8 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x8xf32>, vector<8x8xf32>
      %val7 = vector.shape_cast %val6 : vector<16x8xf32> to vector<8x16xf32>
      // store
      %out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] { mode = vc } : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
      xegpu.store_nd %val7, %out_tile { mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
      gpu.return
    }
  }
  func.func @main() attributes {llvm.emit_c_interface} {
    // init constants
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %c16 = arith.constant 16 : index
    %c1_f32 = arith.constant 1.0 : f32
    // random init
    %lower = arith.constant -1.0 : f32
    %upper = arith.constant 1.0 : f32
    %false = arith.constant 0 : i1
    %A = memref.alloc() : memref<8x16xf16>
    %B = memref.alloc() : memref<16x16xf16>
    %Out_cpu = memref.alloc() : memref<8x16xf32>
    %A_random = memref.cast %A : memref<8x16xf16> to memref<*xf16>
    %B_random = memref.cast %B : memref<16x16xf16> to memref<*xf16>
    call @fillResource1DRandomF16(%A_random, %lower, %upper, %false) : (memref<*xf16>, f32, f32, i1) -> ()
    call @fillResource1DRandomF16(%B_random, %lower, %upper, %false) : (memref<*xf16>, f32, f32, i1) -> ()
    // run GPU version
    %Out_gpu = call @test(%A, %B) : (memref<8x16xf16>, memref<16x16xf16>) -> memref<8x16xf32>
    %Out_gpu_cast = memref.cast %Out_gpu : memref<8x16xf32> to memref<*xf32>
    // run CPU version
    scf.for %i = %c0 to %c8 step %c1 {
      scf.for %j = %c8 to %c16 step %c1 {
        %v0_init = arith.constant 0.0 : f32
        %result:1 = scf.for %k = %c0 to %c16 step %c1 iter_args(%v0 = %v0_init) -> f32 {
          %a0 = memref.load %A[%i, %k] : memref<8x16xf16>
          %b0 = memref.load %B[%k, %j] : memref<16x16xf16>
          %a0_f32 = arith.extf %a0 : f16 to f32
          %b0_f32 = arith.extf %b0 : f16 to f32
          %t0 = arith.mulf %a0_f32, %b0_f32 : f32
          %v0_new = arith.addf %v0, %t0 : f32
          scf.yield %v0_new : f32
        }
        // only update the first 8x8 of the result; the next 8x8 stays at value 1
        %shifted_j = arith.subi %j, %c8 : index
        memref.store %result#0, %Out_cpu[%i, %shifted_j] : memref<8x16xf32>
        memref.store %c1_f32, %Out_cpu[%i, %j] : memref<8x16xf32>
      }
    }
    %Out_cpu_cast = memref.cast %Out_cpu : memref<8x16xf32> to memref<*xf32>
    // print GPU and CPU outs
    // call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> ()
    // call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> ()
    // CHECK: [ALLCLOSE: TRUE]
    call @printAllcloseF32(%Out_gpu_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> ()
    // dealloc
    memref.dealloc %A : memref<8x16xf16>
    memref.dealloc %B : memref<16x16xf16>
    memref.dealloc %Out_cpu : memref<8x16xf32>
    // gpu dealloc
    gpu.dealloc %Out_gpu : memref<8x16xf32>
    return
  }
  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
}
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
module @gemm attributes {gpu.container_module} {
  func.func @test(%A: memref<32x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %memref = gpu.alloc host_shared () : memref<32x16xf32>
    memref.copy %A, %memref : memref<32x16xf32> to memref<32x16xf32>
    %memref_1 = gpu.alloc host_shared () : memref<8x16xf32>
    gpu.launch_func @module0::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x16xf32>, %memref_1 : memref<8x16xf32>)
    gpu.dealloc %memref : memref<32x16xf32>
    return %memref_1 : memref<8x16xf32>
  }

  gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
    gpu.func @test_kernel(%A: memref<32x16xf32>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %c0 = arith.constant 0 : index
      %c16 = arith.constant 16 : index
      // load tile (array_length = 2 yields two 32x8 blocks as a vector<2x32x8xf32>)
      %tile = xegpu.create_nd_tdesc %A [%c0, %c0] {mode = vc} : memref<32x16xf32> -> !xegpu.tensor_desc<32x8xf32, #xegpu.tdesc_attr<array_length = 2>>
      %value = xegpu.load_nd %tile {mode = vc} : !xegpu.tensor_desc<32x8xf32, #xegpu.tdesc_attr<array_length = 2>> -> vector<2x32x8xf32>
      // extract the bottom 8x8 part of the first 32x8 block
      %sub_tile0 = vector.extract_strided_slice %value { offsets = [0, 24], strides = [1, 1], sizes = [1, 8] } : vector<2x32x8xf32> to vector<1x8x8xf32>
      // extract the bottom 8x8 part of the second 32x8 block
      %sub_tile1 = vector.extract_strided_slice %value { offsets = [1, 24], strides = [1, 1], sizes = [1, 8] } : vector<2x32x8xf32> to vector<1x8x8xf32>
      // combine these two 8x8 tiles into a single 8x16 tile
      %t1 = vector.shape_cast %sub_tile0 : vector<1x8x8xf32> to vector<8x8xf32>
      %t2 = vector.shape_cast %sub_tile1 : vector<1x8x8xf32> to vector<8x8xf32>
      %t3 = vector.shuffle %t1, %t2 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x8xf32>, vector<8x8xf32>
      %t4 = vector.shape_cast %t3 : vector<16x8xf32> to vector<8x16xf32>

      // store the result
      %out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] {mode = vc} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
      xegpu.store_nd %t4, %out_tile {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
      gpu.return
    }
  }
  func.func @main() attributes {llvm.emit_c_interface} {
    // init constants
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c32 = arith.constant 32 : index
    %c16 = arith.constant 16 : index
    %c24 = arith.constant 24 : index
    %c1_f32 = arith.constant 1.0 : f32
    %A = memref.alloc() : memref<32x16xf32>
    %Out_cpu = memref.alloc() : memref<8x16xf32>
    // fill A with the values 0, 1, ..., 511
    scf.for %i = %c0 to %c32 step %c1 {
      scf.for %j = %c0 to %c16 step %c1 {
        %t1 = arith.muli %i, %c16 : index
        %val = arith.addi %t1, %j : index
        %val_i32 = arith.index_cast %val : index to i32
        %val_f32 = arith.sitofp %val_i32 : i32 to f32
        %cond = arith.cmpi "sge", %i, %c24 : index
        // only store the bottom 8x16 into Out_cpu
        scf.if %cond {
          %i_cpu = arith.subi %i, %c24 : index
          memref.store %val_f32, %Out_cpu[%i_cpu, %j] : memref<8x16xf32>
        }
        memref.store %val_f32, %A[%i, %j] : memref<32x16xf32>
      }
    }
    // run GPU version
    %Out_gpu = call @test(%A) : (memref<32x16xf32>) -> memref<8x16xf32>
    %Out_gpu_cast = memref.cast %Out_gpu : memref<8x16xf32> to memref<*xf32>
    %A_cast = memref.cast %A : memref<32x16xf32> to memref<*xf32>
    %Out_cpu_cast = memref.cast %Out_cpu : memref<8x16xf32> to memref<*xf32>
    // print GPU and CPU outs
    // call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> ()
    // call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> ()
    // CHECK: [ALLCLOSE: TRUE]
    call @printAllcloseF32(%Out_gpu_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> ()
    // dealloc
    memref.dealloc %A : memref<32x16xf32>
    // gpu dealloc
    gpu.dealloc %Out_gpu : memref<8x16xf32>
    return
  }
  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
}
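
The expected output of this test can likewise be read off the index arithmetic of the lowering: for %sub_tile0 the offsets [0, 24] give slice starts (24 + i) * 8 = 192 + 8i, i.e. elements 192-255 of the flattened vector<512xf32>, while for %sub_tile1 the leading offset 1 adds the block stride 32 * 8 = 256, giving elements 448-511. Those are rows 24-31 of the first and second 32x8 blocks, which is exactly the bottom 8x16 of A that the CPU reference stores into %Out_cpu.
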
