Skip to content

Commit 9213451

Browse files
committed
Prevent non 2d shaped loads/stores to have an sg_map
1 parent dfca5d6 commit 9213451

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,15 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
8181
// each dimension.
8282
static bool isArgShapesValid(ArrayRef<int64_t> descShape,
8383
ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
84-
if (descShape == valShape)
85-
return true;
84+
if (descShape == valShape) {
85+
if (!sgMap)
86+
return true;
87+
88+
// this can be relaxed if necessary by supporting non-2d shapes distribution
89+
// until the constraints are defined this lives here instead of the tensor
90+
// descriptor type.
91+
return valShape.size() == sgMap.getWiLayout().size();
92+
}
8693

8794
if (!sgMap)
8895
return false;

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,15 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
8686
return
8787
}
8888

89+
// -----
90+
func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
91+
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
92+
!xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
93+
// expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
94+
%2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
95+
return
96+
}
97+
8998
// -----
9099
func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
91100
%1 = arith.constant dense<1.0>: vector<24x32xf16>

0 commit comments

Comments
 (0)