
Commit 132bcff

clean up and update the test
1 parent 7717fa7 commit 132bcff


6 files changed, +30 -41 lines changed


mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 21 additions & 32 deletions
@@ -131,62 +131,51 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
     }
 
-    BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
-      return llvm::dyn_cast_if_present<BlockTensorDescAttr>(getEncoding());
-    }
-
-    ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr() const {
-      return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
+    template <typename T,
+              typename = std::enable_if_t<
+                  std::is_same_v<T, BlockTensorDescAttr> ||
+                  std::is_same_v<T, ScatterTensorDescAttr>>>
+    T getEncodingOfType() const {
+      return llvm::dyn_cast_if_present<T>(getEncoding());
     }
 
     LayoutAttr getLayoutAttr() const {
       return llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
     }
 
     xegpu::MemorySpace getMemorySpace() const {
-      auto block_attr = getEncodingAsBlockTensorDescAttr();
-      if (block_attr && block_attr.getMemorySpace())
-        return block_attr.getMemorySpace().getValue();
+      if (auto attr = getEncodingOfType<BlockTensorDescAttr>())
+        return attr.getMemorySpace().getValue();
 
-      auto scatter_attr = getEncodingAsScatterTensorDescAttr();
-      if (scatter_attr && scatter_attr.getMemorySpace())
-        return scatter_attr.getMemorySpace().getValue();
+      if (auto attr = getEncodingOfType<ScatterTensorDescAttr>())
+        return attr.getMemorySpace().getValue();
 
-      // return default value
+      llvm_unreachable("invalid encoding");
       return MemorySpace::Global;
     }
 
     // get the ArrayLength for blocked TensorDesc
     int getArrayLength() {
-      auto attr = getEncoding();
-      auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
-      assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
-      if (block_attr && block_attr.getArrayLength())
-        return block_attr.getArrayLength().getInt();
-      // return default value
-      return 1;
+      auto attr = getEncodingOfType<BlockTensorDescAttr>();
+      assert(attr && "invalid on non BlockTensorDescAttr.");
+      return attr.getArrayLength().getInt();
     }
 
     bool getBoundaryCheck() {
-      auto attr = getEncoding();
-      auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
-      assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
-      if (block_attr && block_attr.getBoundaryCheck())
-        return block_attr.getBoundaryCheck().getValue();
-      // return default value
-      return true;
+      auto attr = getEncodingOfType<BlockTensorDescAttr>();
+      assert(attr && "invalid on non BlockTensorDescAttr.");
+      return attr.getBoundaryCheck().getValue();
    }
 
     bool isScattered() {
-      return bool(getEncodingAsScatterTensorDescAttr());
+      return bool(getEncodingOfType<ScatterTensorDescAttr>());
     }
 
     // get the ChunkSize for scattered TensorDesc
     int getChunkSizeAsInt() {
-      auto attr = getEncoding();
-      auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
-      assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
-      return scatter_attr.getChunkSizeAsInt();
+      auto attr = getEncodingOfType<ScatterTensorDescAttr>();
+      assert(attr && "invalid on non ScatterTensorDescAttr.");
+      return attr.getChunkSizeAsInt();
     }
 
     /// Helper to drop all layout information from the TensorDesc type.
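
The new getEncodingOfType<T>() collapses the two per-kind getters into a single accessor whose template parameter is restricted, via std::enable_if_t and std::is_same_v, to the two legal encoding attributes. Below is a minimal, self-contained sketch of that constraint pattern; Encoding, BlockEnc, ScatterEnc, and Desc are illustrative stand-ins rather than the MLIR classes, and dynamic_cast plays the role that llvm::dyn_cast_if_present plays in the real code.

#include <cassert>
#include <memory>
#include <type_traits>

struct Encoding { virtual ~Encoding() = default; };
struct BlockEnc : Encoding { int arrayLength = 1; };
struct ScatterEnc : Encoding { int chunkSize = 8; };

class Desc {
  std::shared_ptr<Encoding> encoding;

public:
  explicit Desc(std::shared_ptr<Encoding> e) : encoding(std::move(e)) {}

  // Only the two whitelisted encodings are legal template arguments; any
  // other instantiation is rejected at compile time, mirroring the
  // std::enable_if_t guard in the patch.
  template <typename T,
            typename = std::enable_if_t<std::is_same_v<T, BlockEnc> ||
                                        std::is_same_v<T, ScatterEnc>>>
  T *getEncodingOfType() const {
    // dynamic_cast stands in for llvm::dyn_cast_if_present: it yields
    // nullptr when the stored encoding is absent or of another kind.
    return dynamic_cast<T *>(encoding.get());
  }
};

int main() {
  Desc blocked(std::make_shared<BlockEnc>());
  Desc scattered(std::make_shared<ScatterEnc>());

  assert(blocked.getEncodingOfType<BlockEnc>() != nullptr);
  assert(blocked.getEncodingOfType<ScatterEnc>() == nullptr);
  assert(scattered.getEncodingOfType<ScatterEnc>()->chunkSize == 8);
  // blocked.getEncodingOfType<int>() would not compile.
  return 0;
}

Asking for any type outside the whitelist fails to compile, which is the same guarantee the templated accessor adds to TensorDescType while keeping a single point of truth for the dyn_cast logic.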

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
                                   std::multiplies<int64_t>());
 
   // Case 1: regular loads/stores
-  auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
+  auto scatterAttr = tdescTy.getEncodingOfType<ScatterTensorDescAttr>();
   if (scatterAttr) {
     auto chunkSize = scatterAttr.getChunkSize().getInt();
     // Verify if the first dimension of the tensor descriptor shape is
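
The call-site update follows the same idea: rather than a dedicated getEncodingAsScatterTensorDescAttr(), the caller requests one specific encoding kind and branches on whether it is present. A loose, self-contained analogy using std::variant is sketched below; Descriptor, BlockEncoding, and ScatterEncoding are hypothetical names, and std::get_if stands in for dyn_cast_if_present returning a null attribute on mismatch.

#include <iostream>
#include <variant>

struct BlockEncoding { int arrayLength = 1; };
struct ScatterEncoding { int chunkSize = 4; };

struct Descriptor {
  std::variant<BlockEncoding, ScatterEncoding> encoding;

  // std::get_if returns nullptr when the variant holds the other
  // alternative, and requesting a type that is not an alternative at all
  // does not compile.
  template <typename T>
  const T *getEncodingOfType() const { return std::get_if<T>(&encoding); }
};

int main() {
  Descriptor desc{ScatterEncoding{8}};

  // Case 1: scattered descriptors carry a chunk size.
  if (const auto *scatter = desc.getEncodingOfType<ScatterEncoding>())
    std::cout << "scattered, chunk size = " << scatter->chunkSize << "\n";
  else
    std::cout << "blocked layout\n";
  return 0;
}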

mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-SAME: %[[OFFSET:.+]]: index
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
 // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]
 
@@ -55,7 +55,7 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
 // CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]
 
@@ -73,7 +73,7 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
 // CHECK-SAME: %[[OFFSET:.+]]: index
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
 // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]

mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME: %[[OFFSET:.+]]: index
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
 // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
 
 // -----
 
@@ -57,7 +57,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
 // CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
 
 // -----
 
@@ -75,7 +75,7 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // CHECK-SAME: %[[OFFSET:.+]]: index
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
 // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
 
 // -----

mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
 // CHECK-SAME: %[[SRC:.+]]: memref<32x64xf32>,
 // CHECK-SAME: %[[OFFSET:.+]]: index
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]

mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // CHECK-SAME: %[[OFFSET:.+]]: index
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
 // CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
 
 // -----
