Commit c333f7d

[mlir][xegpu] Add layout based SIMT distribution support for vector.extract/insert_strided_slice (#168626)
This PR adds general SIMT distribution support for `vector.extract_strided_slice` and `vector.insert_strided_slice`. Vector distribution already supports these operations, but with restrictions that avoid needing layouts in the distribution logic. For example, `extract_strided_slice` requires that the distributed dimension is fully extracted. More complex cases, however, may need to extract only part of the distributed dimension (e.g., an 8x16xf16 extraction from 8x32xf16). Such cases require the layouts to reason about how the data is spread across SIMT lanes. Since layout information is not available in the generic vector distribution, these new patterns are placed on the XeGPU side. They are assigned a higher pattern benefit so that they are tried before the regular vector-distribution patterns.
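
For illustration, below is a minimal standalone C++ sketch of the offset/size arithmetic these patterns apply, assuming unit lane data, a single distributed dimension, a subgroup size of 16, and an extraction offset of 16 along that dimension. The helper and values are illustrative only; they are not part of this patch or of the MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the adjustment the new patterns perform:
// with unit lane data, each lane owns every `subgroupSize`-th element of the
// distributed dimension, so per-lane sizes and offsets shrink by that factor.
struct SliceParams {
  int64_t size;   // size along the distributed dimension
  int64_t offset; // offset along the distributed dimension
};

static SliceParams distributeSlice(int64_t size, int64_t offset,
                                   int64_t subgroupSize) {
  // The patterns bail out unless both values are multiples of subgroupSize.
  assert(size % subgroupSize == 0 && offset % subgroupSize == 0);
  return {size / subgroupSize, offset / subgroupSize};
}

int main() {
  // Example from the description: extract 8x16xf16 out of 8x32xf16 with the
  // second dimension distributed over 16 lanes (lane_layout = [1, 16]).
  SliceParams perLane = distributeSlice(/*size=*/16, /*offset=*/16,
                                        /*subgroupSize=*/16);
  // Each lane, which owns an 8x2 slice of the 8x32 source, extracts an 8x1
  // slice at per-lane offset [0, 1].
  std::printf("per-lane size = %lld, per-lane offset = %lld\n",
              (long long)perLane.size, (long long)perLane.offset);
  return 0;
}
```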
1 parent 44c9d3a commit c333f7d

File tree

2 files changed: +799 -291 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 242 additions & 3 deletions
@@ -174,6 +174,21 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
   return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
 }
 
+/// Given a vector type and its distributed vector type, return the list of
+/// dimensions that are distributed.
+static SmallVector<int64_t> getDistributedDims(VectorType originalType,
+                                               VectorType distributedType) {
+  assert(originalType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  SmallVector<int64_t> distributedDims;
+  for (int64_t i = 0; i < originalType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
+      distributedDims.push_back(i);
+    }
+  }
+  return distributedDims;
+}
+
 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
 /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
 /// contained within a WarpExecuteOnLane0Op.
@@ -1469,6 +1484,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
+// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
+// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+// advanced cases where the distributed dimension is partially extracted and
+// currently not supported by the generic vector distribution patterns.
+struct VectorExtractStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    auto extractOp =
+        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+    unsigned operandIdx = operand->getOperandNumber();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Find the distributed dimensions.
+    auto extractResultType = cast<VectorType>(operand->get().getType());
+    auto distributedDims =
+        getDistributedDims(extractResultType, distributedType);
+    // Collect updated source type, sizes and offsets. They may be adjusted
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    VectorType updatedSourceType = extractOp.getSourceVectorType();
+    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    // If the result is distributed, it must be distributed in exactly one
+    // dimension. In this case, we adjust the sourceDistType, distributedSizes
+    // and distributedOffsets accordingly.
+    if (distributedDims.size() > 0) {
+      if (distributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source can not be distributed in multiple dimensions.");
+      int64_t distributedDim = distributedDims[0];
+      int sourceDistrDimSize =
+          extractOp.getSourceVectorType().getShape()[distributedDim];
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source of extract_strided_slice op lacks distribution "
+                    "layout");
+      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize = sourceLaneLayout[distributedDim];
+      // Check if the source size in the distributed dimension is a multiple of
+      // subgroup size.
+      if (sourceDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Source size along distributed dimension is not a multiple of "
+            "subgroup size.");
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      // We expect lane data to be all ones in this case.
+      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source layout");
+      // The offsets in the distributed dimension must be a multiple of subgroup
+      // size.
+      int64_t distrDimOffset =
+          cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+      if (distrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Offset along distributed dimension "
+                    "is not a multiple of subgroup size.");
+      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+                              sourceLayout, extractOp.getSourceVectorType())
+                              .value();
+      // Update the distributed sizes to match the distributed type.
+      updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+      // Update the distributed offsets to match round robin distribution (i.e.
+      // each lane owns data at `subgroupSize` stride given unit lane data).
+      updatedOffsets[distributedDim] =
+          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+    }
+    // Do the distribution by yielding the source of the extract op from
+    // the warp op and creating a new extract op outside the warp op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value source = newWarpOp.getResult(newRetIndices[0]);
+    // Create a new extract op outside the warp op.
+    Value newExtractOp = vector::ExtractStridedSliceOp::create(
+        rewriter, extractOp.getLoc(), distributedType, source,
+        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+        ArrayAttr::get(rewriter.getContext(), updatedSizes),
+        extractOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
+    return success();
+  }
+};
+
+/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
+/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases where the distributed dimension is partially inserted and
+/// currently not supported by the generic vector distribution patterns.
+struct VectorInsertStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Find the distributed dimensions of the dest vector.
+    auto insertResultType = cast<VectorType>(operand->get().getType());
+    auto destDistributedDims =
+        getDistributedDims(insertResultType, distributedType);
+    // Collect updated offsets, source type and dest type. They may be adjusted
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        insertOp.getOffsets(), [](Attribute attr) { return attr; });
+    VectorType updatedSourceType = insertOp.getSourceVectorType();
+    VectorType updatedDestType = insertOp.getDestVectorType();
+    if (destDistributedDims.size() > 0) {
+      // Only single dimension distribution is supported.
+      if (destDistributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Expecting source to be distributed in a single dimension.");
+      int64_t destDistributedDim = destDistributedDims[0];
+
+      VectorType srcType = insertOp.getSourceVectorType();
+      VectorType destType = insertOp.getDestVectorType();
+      // Currently we require that both source (kD) and dest (nD) vectors are
+      // distributed. This requires that distributedDim (d) is contained in the
+      // last k dims of the dest vector (d >= n - k).
+      int64_t sourceDistributedDim =
+          destDistributedDim - (destType.getRank() - srcType.getRank());
+      if (sourceDistributedDim < 0)
+        return rewriter.notifyMatchFailure(
+            insertOp,
+            "distributed dimension must be in the last k (i.e. source "
+            "rank) dims of dest vector");
+      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
+      // Obtain the source and dest layouts.
+      auto destLayout =
+          xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
+      if (!destLayout || !sourceLayout ||
+          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source or dest of insert_strided_slice op lacks "
+                    "distribution layout");
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize =
+          destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
+      // We require that source and dest lane data are all ones to ensure
+      // uniform round robin distribution.
+      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
+          !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source and dest layouts");
+      // Source distributed dim size must be multiples of subgroup size.
+      if (srcDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Distributed dimension size in source is not a multiple of "
+                    "subgroup size.");
+      // Offsets in the distributed dimension must be multiples of subgroup
+      // size.
+      int64_t destDistrDimOffset =
+          cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
+      if (destDistrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Offset along distributed dimension in dest is not a multiple of "
+            "subgroup size.");
+      // Update the source and dest types based on their layouts.
+      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+                              sourceLayout, insertOp.getSourceVectorType())
+                              .value();
+      updatedDestType = getDistVecTypeBasedOnLaneLayout(
+                            destLayout, insertOp.getDestVectorType())
+                            .value();
+      // Update the distributed offsets to match round robin distribution (i.e.
+      // each lane owns data at `subgroupSize` stride given unit lane data).
+      updatedOffsets[destDistributedDim] =
+          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+    }
+    // Do the distribution by yielding the source and dest of the insert op
+    // from the warp op and creating a new insert op outside the warp op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {updatedSourceType, updatedDestType}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
+    Value dest = newWarpOp.getResult(newRetIndices[1]);
+    // Create a new insert op outside the warp op.
+    Value newInsertOp = vector::InsertStridedSliceOp::create(
+        rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
+        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+        insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+                                newInsertOp);
+    return success();
+  }
+};
+
 /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
 /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
 /// outside of the warp op.
@@ -1626,9 +1861,13 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/regularPatternBenefit);
-  patterns.add<VectorShapeCastDistribution>(
-      patterns.getContext(),
-      /*pattern benefit=*/highPatternBenefit);
+  // For following patterns, we need to override the regular vector distribution
+  // patterns. Therefore, assign higher benefit.
+  patterns
+      .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
+           VectorInsertStridedSliceDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/highPatternBenefit);
 }
 
 void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
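
For the `vector.insert_strided_slice` pattern above, the destination's distributed dimension must also map onto a source dimension, since the rank-k source is aligned with the last k dims of the rank-n destination (d >= n - k), and the destination offset is then scaled by the subgroup size in the same way. Below is a minimal standalone sketch of that mapping under the same unit-lane-data assumption; the helper name and shapes are illustrative, not the actual MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <optional>

// Hypothetical sketch: map the destination's distributed dim to the source
// dim (valid only if d >= n - k), mirroring the rank adjustment done by the
// pattern before it rescales offsets for round-robin distributed data.
static std::optional<int64_t> sourceDimFor(int64_t destDistributedDim,
                                           int64_t destRank, int64_t srcRank) {
  int64_t srcDim = destDistributedDim - (destRank - srcRank);
  if (srcDim < 0)
    return std::nullopt; // distributed dim is not covered by the source
  return srcDim;
}

int main() {
  // E.g. inserting a 16x16 source into a 2x16x32 destination, with the
  // innermost destination dim (d = 2) distributed over 16 lanes.
  auto srcDim = sourceDimFor(/*destDistributedDim=*/2, /*destRank=*/3,
                             /*srcRank=*/2);
  assert(srcDim && *srcDim == 1);
  // A destination offset of 16 in the distributed dim becomes 16 / 16 = 1
  // per lane, matching the round-robin ownership of elements.
  int64_t subgroupSize = 16;
  std::printf("source dim = %lld, per-lane offset = %lld\n",
              (long long)*srcDim, (long long)(16 / subgroupSize));
  return 0;
}
```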

0 commit comments
