Skip to content

Commit 385cc49

Browse files
authored
XeVM tests enabling (#1069)
1 parent d3017f3 commit 385cc49

File tree

6 files changed

+22
-66
lines changed

6 files changed

+22
-66
lines changed

lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,6 @@ class CreateNdDescToXeVMPattern
121121
xegpu::CreateNdDescOp::Adaptor adaptor,
122122
ConversionPatternRewriter &rewriter) const override {
123123
auto loc = op.getLoc();
124-
auto resultDesc = cast<TensorDescType>(op.getResult().getType());
125-
auto sgMap = resultDesc.getLayoutAttr();
126-
if (!sgMap) {
127-
op.emitError() << "XeVM expects SGMap attribute to be present for tensor "
128-
"descriptors";
129-
return mlir::failure();
130-
}
131124
auto source = op.getSource();
132125
Type payloadElemTy = rewriter.getI32Type();
133126
Type i64Ty = rewriter.getI64Type();
@@ -292,8 +285,7 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
292285
auto l3 = translateStoreXeGPUCacheHint(op.getL3Hint());
293286
VectorType srcFlatVecTy =
294287
VectorType::get(srcVecTy.getNumElements(), srcVecTy.getElementType());
295-
Value srcFlatVec = rewriter.create<vector::ShapeCastOp>(loc, srcFlatVecTy,
296-
op.getValue());
288+
Value srcFlatVec = op.getValue();
297289
srcFlatVecTy = encodeVectorTypeTo(srcFlatVecTy,
298290
rewriter.getIntegerType(elemBitSize));
299291
srcFlatVec =
@@ -327,9 +319,7 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
327319
resultFlatVec = rewriter.create<vector::BitCastOp>(
328320
loc, encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
329321
resultFlatVec);
330-
auto newOp =
331-
rewriter.create<vector::ShapeCastOp>(loc, dstVecTy, resultFlatVec);
332-
rewriter.replaceOp(op, newOp);
322+
rewriter.replaceOp(op, resultFlatVec);
333323
}
334324
}
335325
return success();
@@ -548,14 +538,8 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
548538
}
549539
auto rc = IntegerAttr::get(rewriter.getI32Type(), 8);
550540

551-
VectorType aNty =
552-
VectorType::get(aTy.getNumElements(), aTy.getElementType());
553-
Value aVec = rewriter.create<vector::ShapeCastOp>(loc, aNty, op.getLhs());
554-
555-
VectorType bNty =
556-
VectorType::get(bTy.getNumElements(), bTy.getElementType());
557-
Value bVec = rewriter.create<vector::ShapeCastOp>(loc, bNty, op.getRhs());
558-
541+
Value aVec = op.getLhs();
542+
Value bVec = op.getRhs();
559543
auto cvecty = cast<VectorType>(c.getType());
560544
VectorType cNty =
561545
VectorType::get(cvecty.getNumElements(), cvecty.getElementType());

test/Conversion/XeGPUToXeVM/dpas.mlir

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,11 @@
55
#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
66

77
gpu.module @load_store_check {
8-
func.func @dpas(%a_loaded: vector<8x1xf16>, %b_loaded: vector<8x2xf16>, %c_loaded: vector<8x1xf32>) -> vector<8x1xf32> {
8+
//CHECK: func.func @dpas(%[[arg0:.*]]: vector<8xf16>, %[[arg1:.*]]: vector<16xf16>, %[[arg2:.*]]: vector<8xf32>) -> vector<8xf32>
9+
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
910
// Loads are checked in a separate test.
10-
// Cast arguments to SIMT-style vectors.
11-
//CHECK: %[[CAST_A:.*]] = vector.shape_cast %arg0 : vector<8x1xf16> to vector<8xf16>
12-
//CHECK-NEXT: %[[CAST_B:.*]] = vector.shape_cast %arg1 : vector<8x2xf16> to vector<16xf16>
13-
//CHECK-NEXT: %[[CAST_C:.*]] = vector.shape_cast %arg2 : vector<8x1xf32> to vector<8xf32>
14-
//CHECK-NEXT: %[[D:.*]] = xevm.dpas %[[CAST_C]], %[[CAST_A]], %[[CAST_B]] {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xf16>, vector<16xf16>) -> vector<8xf32>
15-
// Cast result back to expected shape
16-
//CHECK-NEXT: %[[CAST_D:.*]] = vector.shape_cast %[[D]] : vector<8xf32> to vector<8x1xf32>
17-
%d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32} : vector<8x1xf16>, vector<8x2xf16>, vector<8x1xf32> -> vector<8x1xf32>
18-
return %d : vector<8x1xf32>
11+
//CHECK: %[[D:.*]] = xevm.dpas %[[arg2]], %[[arg0]], %[[arg1]] {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xf16>, vector<16xf16>) -> vector<8xf32>
12+
%d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32} : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
13+
return %d : vector<8xf32>
1914
}
2015
}

test/Conversion/XeGPUToXeVM/lit.local.cfg

Lines changed: 0 additions & 7 deletions
This file was deleted.

test/Conversion/XeGPUToXeVM/loadstore_nd.mlir

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ gpu.module @load_store_check {
1313
// CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
1414
// CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
1515
// CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
16-
%src_tdesc = xegpu.create_nd_tdesc %srcce[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
16+
%src_tdesc = xegpu.create_nd_tdesc %srcce[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
17+
1718

1819
//CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
1920
//CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
@@ -25,15 +26,14 @@ gpu.module @load_store_check {
2526
//CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
2627
//CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
2728
//CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]] {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, transpose = false, vnni_transform = false, l1_cache_control = C, l3_cache_control = UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
28-
%loaded = xegpu.load_nd %src_tdesc <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x1xf32>
29+
%loaded = xegpu.load_nd %src_tdesc <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
2930
//CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
30-
//CHECK: %[[LD_LOADED_F32_DISTRIBUTED:.*]] = vector.shape_cast %[[LD_LOADED_F32]] : vector<8xf32> to vector<8x1xf32>
3131

3232
%tid_x = gpu.thread_id x
3333
%tid_x_i32 = arith.index_cast %tid_x : index to i32
3434
%tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
35-
//CHECK: %[[LOADED_F32_DISTRIBUTED_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32_DISTRIBUTED]] [0, 0] : f32 into vector<8x1xf32>
36-
%loaded_modified = vector.insert %tid_x_f32, %loaded[0, 0] : f32 into vector<8x1xf32>
35+
//CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
36+
%loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
3737

3838
// CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
3939
// CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
@@ -43,7 +43,7 @@ gpu.module @load_store_check {
4343
// CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
4444
// CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
4545
// CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
46-
%dst_tdesc = xegpu.create_nd_tdesc %dstte[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
46+
%dst_tdesc = xegpu.create_nd_tdesc %dstte[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
4747

4848
//CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
4949
//CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
@@ -54,10 +54,9 @@ gpu.module @load_store_check {
5454
//CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
5555
//CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
5656
//CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
57-
//CHECK: %[[FLAT_VALUE:.*]] = vector.shape_cast %[[LOADED_F32_DISTRIBUTED_MODIFIED]] : vector<8x1xf32> to vector<8xf32>
58-
//CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[FLAT_VALUE]] : vector<8xf32> to vector<8xi32>
57+
//CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
5958
//CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]], %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]] {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, l1_cache_control = WB, l3_cache_control = UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
60-
xegpu.store_nd %loaded_modified, %dst_tdesc <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
59+
xegpu.store_nd %loaded_modified, %dst_tdesc <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
6160
gpu.return
6261
}
6362
}

test/Integration/Dialect/XeGPUToXeVM/lit.local.cfg

Lines changed: 0 additions & 13 deletions
This file was deleted.

test/Integration/Dialect/XeGPUToXeVM/loadstore_scatter_chunk_size_2.mlir

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,25 @@
22
// RUN: --runner imex-cpu-runner -e main \
33
// RUN: --entry-point-result=void \
44
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
5-
6-
#sg_map_a_bf16 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
75
module @gemm attributes {gpu.container_module} {
86
gpu.module @kernel {
97
gpu.func @load_store_2d(%src: memref<128xf32, 1>, %dst: memref<128xf32, 1>) kernel {
108
%srcce = memref.memory_space_cast %src : memref<128xf32, 1> to memref<128xf32>
119
%dstte = memref.memory_space_cast %dst : memref<128xf32, 1> to memref<128xf32>
1210

1311
%offsets = arith.constant dense<[0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30]> : vector<16xindex>
14-
%src_tdesc = xegpu.create_tdesc %srcce, %offsets : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #sg_map_a_bf16>
15-
%dst_tdesc = xegpu.create_tdesc %dstte, %offsets : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #sg_map_a_bf16>
12+
%src_tdesc = xegpu.create_tdesc %srcce, %offsets : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
13+
%dst_tdesc = xegpu.create_tdesc %dstte, %offsets : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
1614

1715
%mask = arith.constant dense<[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]> : vector<16xi1>
18-
%loaded = xegpu.load %src_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #sg_map_a_bf16>, vector<16xi1> -> vector<2x1xf32>
16+
%loaded = xegpu.load %src_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1> -> vector<2xf32>
1917

2018
%tid_x = gpu.thread_id x
2119
%tid_x_i32 = arith.index_cast %tid_x : index to i32
2220
%tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
23-
%loaded_modified = vector.insert %tid_x_f32, %loaded[0,0] : f32 into vector<2x1xf32>
21+
%loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<2xf32>
2422

25-
xegpu.store %loaded_modified, %dst_tdesc, %mask <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #sg_map_a_bf16>, vector<16xi1>
23+
xegpu.store %loaded_modified, %dst_tdesc, %mask <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1>
2624
gpu.return
2725
}
2826
}

0 commit comments

Comments (0)