
Commit bbd0364

[XeGPU] Account for sg_map in CreateNd, LoadNd verification
1 parent: a79098b

2 files changed: 50 additions and 0 deletions

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 23 additions & 0 deletions
@@ -198,6 +198,22 @@ LogicalResult CreateNdDescOp::verify() {
       tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
     return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
 
+  if (auto attr = getType().getSGMapAttr()) {
+    auto wiLayout = attr.getWiLayout();
+    auto wiData = attr.getWiData();
+    if (wiData[0] < 1 || wiData[1] < 1 || (wiData[0] > 1 && wiData[1] > 1))
+      return emitOpError() << "`wi_data` values must be >=1 and can only be >1 "
+                              "along one dimension."
+                           << "\n";
+    auto tdescShape = getType().getShape();
+    for (size_t i = 0; i < tdescShape.size(); i++) {
+      if (tdescShape[i] % wiLayout[i])
+        return emitOpError() << "Work-items must uniformly divide a tile "
+                                "(tdescShape[i] % wiLayout[i] == 0)"
+                             << "\n";
+    }
+  }
+
   return success();
 }
 
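Note (illustrative, not part of the diff): with this check in place, a hypothetical descriptor such as

  !xegpu.tensor_desc<8x14xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>

would be rejected because 14 is not evenly divisible by the wi_layout value 16, and an sg_map with wi_data = [2, 2] would be rejected because wi_data may only be greater than 1 along a single dimension.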

@@ -250,6 +266,13 @@ LogicalResult LoadNdOp::verify() {
   auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
 
+  if (auto attr = getTensorDescType().getSGMapAttr()) {
+    auto wiLayout = attr.getWiLayout();
+    for (size_t i = 0; i < tdescShape.size(); i++) {
+      tdescShape[i] /= wiLayout[i];
+    }
+  }
+
   if (getTranspose()) {
     auto trans = getTranspose().value();
 
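For intuition (taken from the first new test below): a tensor_desc<32x16xi8> with wi_layout = [1, 16] gives a per-work-item shape of [32/1, 16/16] = [32, 1], and the transpose/packed checks that follow now compare against that per-lane shape; with wi_data = [4, 1] and a packed load, the expected result type is vector<8x1x4xi8>, as the test asserts.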

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 27 additions & 0 deletions
@@ -21,6 +21,33 @@ gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<32x32xi8>) {
+gpu.func @test_load_nd_tdesc_with_sg_map(%src: memref<32x32xi8>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>> -> vector<8x1x4xi8>
+  %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>> -> vector<8x1x4xi8>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map_2(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_tdesc_with_sg_map_2(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map_3(%[[arg0:.*]]: memref<32x32xf32>) {
+gpu.func @test_load_nd_tdesc_with_sg_map_3(%src: memref<32x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
 gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index
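Shape check for the transpose case (third test above): the 16x8 tile with wi_layout = [16, 1] yields a per-work-item shape of [16/16, 8/1] = [1, 8]; applying transpose = array<i64: 1, 0> swaps this to [8, 1], matching the asserted result type vector<8x1xf32>.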
