Skip to content

Commit 21f50c0

Browse files
committed
fix issues
1 parent 270b498 commit 21f50c0

File tree

2 files changed

+29
-27
lines changed

2 files changed

+29
-27
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,10 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
378378
// tensorSize must be adjusted for array_length.
379379
tensorSize *= getArrayLength();
380380

381+
if (layout.getRank() == 1) {
382+
return VectorType::get({tensorSize / sgSize}, getElementType());
383+
}
384+
381385
return VectorType::get({tensorSize / (sgSize * laneDataSize), laneDataSize},
382386
getElementType());
383387
}

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "llvm/ADT/STLExtras.h"
3737
#include "llvm/ADT/SmallVector.h"
3838
#include "llvm/ADT/TypeSwitch.h"
39+
#include "llvm/ADT/bit.h"
3940
#include "llvm/Support/Casting.h"
4041
#include "llvm/Support/LogicalResult.h"
4142
#include "llvm/Support/raw_ostream.h"
@@ -781,30 +782,27 @@ namespace {
781782
/// | 2x32x16 | [1, 16] | 2x32x1 |
782783
// Computes a vector type from `originalType` by dividing dimensions of its
// shape by the entries of `layout`'s lane layout, keeping the element type.
// Returns failure() when `layout` is null or when a divided dimension is not
// an exact multiple of its lane-layout entry.
//
// NOTE(review): this span is a rendered diff, not plain source. A line that
// follows an `NNN-` marker is removed by the commit, a line that follows an
// `NNN+` marker is added, and a line that follows a doubled number is
// unchanged context.
FailureOr<VectorType> getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
783784
VectorType originalType) {
784-
llvm::SmallVector<int64_t, 2> distributedShape;
785785
if (!layout)
786786
return failure();
787787

// --- Removed implementation -------------------------------------------------
// Asserted a rank-2 or rank-3 original type and a 2-entry lane layout, divided
// only the last two dimensions, and forwarded dimension 0 of a rank-3 type
// unchanged.
788-
auto laneLayout = layout.getLaneLayout();
789-
assert((originalType.getRank() == 2 || originalType.getRank() == 3) &&
790-
"expecting 2D or 3D shape for the original vector type");
791-
assert(laneLayout.size() == 2 && "expecting 2D shape for the wi layout");
792-
// Original type can be 2D or 3D (array_length > 1), the last two dims are the
793-
// block shape.
794-
auto blockShape = originalType.getShape().take_back(2);
795-
// Check if the block vector shape can be distributed evenly.
796-
if (blockShape[0] % laneLayout[0] != 0 || blockShape[1] % laneLayout[1] != 0)
797-
return failure();
798-
799-
if (originalType.getRank() == 3) {
800-
distributedShape.push_back(originalType.getShape()[0]);
801-
}
802-
for (unsigned i = 0; i < 2; ++i) {
803-
distributedShape.push_back(blockShape[i] / laneLayout[i]);
// --- Added implementation ---------------------------------------------------
// Starts from a copy of the full original shape and overwrites only the
// trailing `laneLayout.size()` dimensions with the divided extents, so the
// lane layout no longer has to have exactly two entries.
788+
auto laneLayout = layout.getLaneLayout().asArrayRef();
// NOTE(review): this precondition is only an assert. If it is compiled out and
// the rank is smaller than laneLayout.size(), the unsigned `distributionStart`
// below would presumably wrap around instead of reporting failure() - confirm
// whether callers can reach this case.
789+
assert(originalType.getShape().size() >= laneLayout.size() &&
790+
"Rank of the original vector type should be greater or equal to the "
791+
"size of the lane layout to distribute the vector type.");
792+
SmallVector<int64_t> distributedShape(originalType.getShape());
793+
/// Only distribute the last `laneLayout.size()` dimensions. The remaining
794+
/// dimensions are not distributed.
// Index of the first dimension that gets divided; dimensions before it keep
// the value copied from the original shape.
795+
unsigned distributionStart = originalType.getRank() - laneLayout.size();
796+
for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
797+
if (i < distributionStart) {
798+
continue;
799+
}
800+
/// Check if the dimension can be distributed evenly.
// `i - distributionStart` maps the shape index onto the lane-layout index.
801+
if (dim % laneLayout[i - distributionStart] != 0)
802+
return failure();
803+
distributedShape[i] = dim / laneLayout[i - distributionStart];
804804
}
805-
auto newVectorType =
806-
VectorType::get(distributedShape, originalType.getElementType());
807-
return newVectorType;
// The result has the computed shape and the original element type.
805+
return VectorType::get(distributedShape, originalType.getElementType());
808806
}
809807

810808
static VectorType getDistributedVectorType(xegpu::LayoutAttr layout,
@@ -1028,15 +1026,14 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
10281026
return rewriter.notifyMatchFailure(
10291027
storeOp, "the source tensor descriptor lacks sg_map attribute");
10301028

1031-
if (storeOp.getTensorDescType().getShape().size() != 2)
1032-
return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
1033-
1034-
auto distriburtedTypeByWarpOp =
1029+
auto distributedTypeByWarpOpOrFailure =
10351030
getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
1036-
if (failed(distriburtedTypeByWarpOp))
1031+
if (failed(distributedTypeByWarpOpOrFailure))
10371032
return rewriter.notifyMatchFailure(storeOp,
10381033
"Failed to distribute the type");
1039-
VectorType distributedTypeByWarpOp = distriburtedTypeByWarpOp.value();
1034+
VectorType distributedTypeByWarpOp =
1035+
distributedTypeByWarpOpOrFailure.value();
1036+
llvm::errs() << "distributed type: " << distributedTypeByWarpOp << "\n";
10401037

10411038
SmallVector<size_t> newRetIndices;
10421039
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1066,7 +1063,8 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
10661063
newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[1]));
10671064

10681065
rewriter.create<xegpu::StoreNdOp>(newWarpOp.getLoc(), TypeRange{},
1069-
newStoreOperands, storeOp->getAttrs());
1066+
newStoreOperands);
1067+
storeOp->setDialectAttrs(storeOp->getDialectAttrs());
10701068
rewriter.eraseOp(storeOp);
10711069
return success();
10721070
}

0 commit comments

Comments
 (0)