Skip to content

Commit 9888c84

Browse files
committed
fix
1 parent 1c06920 commit 9888c84

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
327 327
let hasVerifier = 1;
328 328
}
329 329

330-
def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>,
331-
AllElementTypesMatch<["value", "TensorDesc"]>]> {
330+
def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
332 331
let summary = "stores a n-D block register region back to memory, currently only supports 2D";
333 332

334 333
let description = [{

mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,11 @@ getDistributedTensorDescType(xegpu::TensorDescType originalT,
170 170
for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
171 171
if (!divisible(APInt(64, o), APInt(64, l)))
172 172
return failure();
173-
distributedShape.push_back(o / l);
173+
// Tensor descriptor is distributed only for the scattered case.
174+
if (originalT.isScattered())
175+
distributedShape.push_back(o / l);
176+
else
177+
distributedShape.push_back(o);
174 178
}
175 179
xegpu::TensorDescType distributedDescType;
176 180
if (originalT.isScattered()) {

0 commit comments

Comments
 (0)