Skip to content

Commit 537ca0e

Browse files
committed
add missing logic
1 parent 99c340b commit 537ca0e

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "llvm/ADT/ArrayRef.h"
3535
#include "llvm/ADT/STLExtras.h"
3636
#include "llvm/ADT/SmallVector.h"
37+
#include "llvm/ADT/SmallVectorExtras.h"
3738

3839
namespace mlir {
3940
namespace xegpu {
@@ -876,15 +877,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
876877
// Step 3: Apply subgroup to workitem distribution patterns.
877878
RewritePatternSet patterns(&getContext());
878879
xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
879-
// TODO: distributionFn and shuffleFn are not used at this point.
880+
// distributionFn is used by vector distribution patterns to determine the
881+
// distributed vector type for a given vector value. In XeGPU subgroup
882+
// distribution context, we compute this based on lane layout.
880883
auto distributionFn = [](Value val) {
881884
VectorType vecType = dyn_cast<VectorType>(val.getType());
882885
int64_t vecRank = vecType ? vecType.getRank() : 0;
883-
OpBuilder builder(val.getContext());
884886
if (vecRank == 0)
885887
return AffineMap::get(val.getContext());
886-
return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
888+
// Get the layout of the vector type.
889+
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
890+
// If no layout is specified, assume the inner most dimension is distributed
891+
// for now.
892+
if (!layout)
893+
return AffineMap::getMultiDimMapWithTargets(
894+
vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
895+
SmallVector<unsigned int> distributedDims;
896+
// Get the distributed dimensions based on the layout.
897+
ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
898+
for (unsigned i = 0; i < laneLayout.size(); ++i) {
899+
if (laneLayout[i] > 1)
900+
distributedDims.push_back(i);
901+
}
902+
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
903+
val.getContext());
887904
};
905+
// TODO: shuffleFn is not used.
888906
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
889907
int64_t warpSz) { return Value(); };
890908
vector::populatePropagateWarpVectorDistributionPatterns(

0 commit comments

Comments
 (0)