Skip to content

Commit 4612f64

Browse files
committed
Use min shape for dist_unit
1 parent cd4f9fc commit 4612f64

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
155155
SmallVector<int64_t> distUnitShape(sgLayout.size());
156156
SmallVector<Value> localOffset(sgLayout.size());
157157
for (size_t i = 0; i < sgLayout.size(); i++) {
158-
distUnitShape[i] = sgLayout[i] * sgShape[i];
158+
distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
159159
localOffset[i] =
160160
rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
161161
}

0 commit comments

Comments
 (0)