Skip to content

Commit 9fd715c

Browse files
Garra1980 and gune42 authored
Don't assert for 8/16-bit integers element type vector (#850)
Co-authored-by: Gune <[email protected]>
1 parent 5c647e7 commit 9fd715c

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

lib/Utils/XeCommon.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ encodeVectorType(mlir::ConversionPatternRewriter &rewriter,
7474
bool enforceInteger) {
7575
mlir::Type srcElemType = type.getElementType();
7676
assert((srcElemType.isF16() || srcElemType.isBF16() || srcElemType.isF32() ||
77+
srcElemType.isInteger(8) || srcElemType.isInteger(16) ||
7778
srcElemType.isInteger(32) || srcElemType.isInteger(64)) &&
7879
"Unsupported vector element type.");
7980
const uint32_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc-rawsend-false.pp \
2+
// RUN: --runner imex-cpu-runner -e main \
3+
// RUN: --entry-point-result=void \
4+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
5+
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc-rawsend-false.pp \
6+
// RUN: --runner imex-cpu-runner -e main \
7+
// RUN: --entry-point-result=void \
8+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
9+
module @gemm attributes {gpu.container_module} {
10+
func.func @test(%arg0: memref<8x16xi8>, %arg1: memref<8x16xi8>, %arg2: memref<8x16xi16>, %arg3: memref<8x16xi16>) -> (memref<8x16xi32>, memref<8x16xi32>) attributes {llvm.emit_c_interface} {
11+
%c1 = arith.constant 1 : index
12+
%c8 = arith.constant 8 : index
13+
14+
%memref = gpu.alloc host_shared () : memref<8x16xi8>
15+
memref.copy %arg0, %memref : memref<8x16xi8> to memref<8x16xi8>
16+
%memref_1 = gpu.alloc host_shared () : memref<8x16xi8>
17+
memref.copy %arg1, %memref_1 : memref<8x16xi8> to memref<8x16xi8>
18+
%memref_2 = gpu.alloc host_shared () : memref<8x16xi32>
19+
20+
%memref_3 = gpu.alloc host_shared () : memref<8x16xi16>
21+
memref.copy %arg2, %memref_3 : memref<8x16xi16> to memref<8x16xi16>
22+
%memref_4 = gpu.alloc host_shared () : memref<8x16xi16>
23+
memref.copy %arg3, %memref_4 : memref<8x16xi16> to memref<8x16xi16>
24+
%memref_5 = gpu.alloc host_shared () : memref<8x16xi32>
25+
26+
gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%memref : memref<8x16xi8>, %memref_1 : memref<8x16xi8>, %memref_3 : memref<8x16xi16>, %memref_4 : memref<8x16xi16>, %memref_2 : memref<8x16xi32>, %memref_5 : memref<8x16xi32>)
27+
gpu.dealloc %memref : memref<8x16xi8>
28+
gpu.dealloc %memref_1 : memref<8x16xi8>
29+
gpu.dealloc %memref_3 : memref<8x16xi16>
30+
gpu.dealloc %memref_4 : memref<8x16xi16>
31+
return %memref_2, %memref_5 : memref<8x16xi32>, memref<8x16xi32>
32+
}
33+
34+
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
35+
gpu.func @test_kernel(%arg0: memref<8x16xi8>, %arg1: memref<8x16xi8>, %arg2: memref<8x16xi16>, %arg3: memref<8x16xi16>, %arg4: memref<8x16xi32>, %arg5: memref<8x16xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
36+
%thread_id_x = gpu.thread_id x
37+
cf.br ^bb1
38+
^bb1:
39+
%0 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0] : memref<8x16xi8> -> !xegpu.tensor_desc<16xi8>
40+
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xi8> -> vector<16xi8>
41+
%2 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0] : memref<8x16xi8> -> !xegpu.tensor_desc<16xi8>
42+
%3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16xi8> -> vector<16xi8>
43+
%4 = arith.addi %3, %1 : vector<16xi8>
44+
%5 = arith.extui %4 :vector<16xi8> to vector<16xi32>
45+
%6 = xegpu.create_nd_tdesc %arg4[%thread_id_x, 0] : memref<8x16xi32> -> !xegpu.tensor_desc<16xi32>
46+
xegpu.store_nd %5, %6 : vector<16xi32>, !xegpu.tensor_desc<16xi32>
47+
48+
%7 = xegpu.create_nd_tdesc %arg2[%thread_id_x, 0] : memref<8x16xi16> -> !xegpu.tensor_desc<16xi16>
49+
%8 = xegpu.load_nd %7 : !xegpu.tensor_desc<16xi16> -> vector<16xi16>
50+
%9 = xegpu.create_nd_tdesc %arg3[%thread_id_x, 0] : memref<8x16xi16> -> !xegpu.tensor_desc<16xi16>
51+
%10 = xegpu.load_nd %9 : !xegpu.tensor_desc<16xi16> -> vector<16xi16>
52+
%11 = arith.addi %8, %10 : vector<16xi16>
53+
%12 = arith.extui %11 :vector<16xi16> to vector<16xi32>
54+
%13 = xegpu.create_nd_tdesc %arg5[%thread_id_x, 0] : memref<8x16xi32> -> !xegpu.tensor_desc<16xi32>
55+
xegpu.store_nd %12, %13 : vector<16xi32>, !xegpu.tensor_desc<16xi32>
56+
57+
gpu.return
58+
}
59+
}
60+
func.func @main() attributes {llvm.emit_c_interface} {
61+
%c0 = arith.constant 0 : index
62+
%c1 = arith.constant 1 : index
63+
%c8 = arith.constant 8 : index
64+
%c16 = arith.constant 16 : index
65+
66+
%A = memref.alloc() : memref<8x16xi8>
67+
%B = memref.alloc() : memref<8x16xi8>
68+
%C = memref.alloc() : memref<8x16xi16>
69+
%D = memref.alloc() : memref<8x16xi16>
70+
71+
scf.for %i = %c0 to %c8 step %c1 {
72+
scf.for %j = %c0 to %c16 step %c1 {
73+
%val = index.castu %j : index to i8
74+
%val2 = index.castu %j : index to i16
75+
memref.store %val, %A[%i, %j] : memref<8x16xi8>
76+
memref.store %val, %B[%i, %j] : memref<8x16xi8>
77+
memref.store %val2, %C[%i, %j] : memref<8x16xi16>
78+
memref.store %val2, %D[%i, %j] : memref<8x16xi16>
79+
}
80+
}
81+
82+
%res0, %res1 = call @test(%A, %B, %C, %D) : (memref<8x16xi8>, memref<8x16xi8>, memref<8x16xi16>, memref<8x16xi16>) -> (memref<8x16xi32>, memref<8x16xi32>)
83+
84+
%res0_cast = memref.cast %res0 : memref<8x16xi32> to memref<*xi32>
85+
%res1_cast = memref.cast %res1 : memref<8x16xi32> to memref<*xi32>
86+
87+
// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
88+
// CHECK-SAME: rank = 2 offset = 0 sizes = [8, 16] strides = [16, 1] data =
89+
// CHECK-NEXT{LITERAL}: [[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
90+
// CHECK-NEXT{LITERAL}: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
91+
// CHECK-NEXT{LITERAL}: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
92+
// CHECK-NEXT{LITERAL}: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
93+
// CHECK-NEXT{LITERAL}: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
94+
95+
// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
96+
// CHECK-SAME: rank = 2 offset = 0 sizes = [8, 16] strides = [16, 1] data =
97+
// CHECK-NEXT{LITERAL}: [[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
98+
// CHECK-NEXT{LITERAL}: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
99+
// CHECK-NEXT{LITERAL}: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
100+
// CHECK-NEXT{LITERAL}: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
101+
// CHECK-NEXT{LITERAL}: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
102+
103+
call @printMemrefI32(%res0_cast) : (memref<*xi32>) -> ()
104+
call @printMemrefI32(%res1_cast) : (memref<*xi32>) -> ()
105+
return
106+
}
107+
func.func private @printMemrefI32(memref<*xi32>) attributes {llvm.emit_c_interface}
108+
}

0 commit comments

Comments
 (0)