|
34 | 34 | #include "llvm/ADT/ArrayRef.h" |
35 | 35 | #include "llvm/ADT/STLExtras.h" |
36 | 36 | #include "llvm/ADT/SmallVector.h" |
| 37 | +#include "llvm/ADT/SmallVectorExtras.h" |
37 | 38 |
|
38 | 39 | namespace mlir { |
39 | 40 | namespace xegpu { |
@@ -876,15 +877,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() { |
876 | 877 | // Step 3: Apply subgroup to workitem distribution patterns. |
877 | 878 | RewritePatternSet patterns(&getContext()); |
878 | 879 | xegpu::populateXeGPUSubgroupDistributePatterns(patterns); |
879 | | - // TODO: distributionFn and shuffleFn are not used at this point. |
| 880 | + // distributionFn is used by vector distribution patterns to determine the |
| 881 | + // distributed vector type for a given vector value. In XeGPU subgroup |
| 882 | + // distribution context, we compute this based on lane layout. |
880 | 883 | auto distributionFn = [](Value val) { |
881 | 884 | VectorType vecType = dyn_cast<VectorType>(val.getType()); |
882 | 885 | int64_t vecRank = vecType ? vecType.getRank() : 0; |
883 | | - OpBuilder builder(val.getContext()); |
884 | 886 | if (vecRank == 0) |
885 | 887 | return AffineMap::get(val.getContext()); |
886 | | - return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext()); |
| 888 | + // Get the layout of the vector type. |
| 889 | + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val); |
| 890 | + // If no layout is specified, assume the innermost dimension is distributed
| 891 | + // for now. |
| 892 | + if (!layout) |
| 893 | + return AffineMap::getMultiDimMapWithTargets( |
| 894 | + vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext()); |
| 895 | + SmallVector<unsigned int> distributedDims; |
| 896 | + // Get the distributed dimensions based on the layout. |
| 897 | + ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef(); |
| 898 | + for (unsigned i = 0; i < laneLayout.size(); ++i) { |
| 899 | + if (laneLayout[i] > 1) |
| 900 | + distributedDims.push_back(i); |
| 901 | + } |
| 902 | + return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims, |
| 903 | + val.getContext()); |
887 | 904 | }; |
| 905 | + // TODO: shuffleFn is not used. |
888 | 906 | auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx, |
889 | 907 | int64_t warpSz) { return Value(); }; |
890 | 908 | vector::populatePropagateWarpVectorDistributionPatterns( |
|
0 commit comments