Skip to content

Commit 6a0db88

Browse files
committed
handle simple cases
1 parent 2c97c98 commit 6a0db88

File tree

2 files changed

+191
-129
lines changed

2 files changed

+191
-129
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 143 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,68 +1505,72 @@ struct VectorExtractStridedSliceDistribution
15051505
auto extractResultType = cast<VectorType>(operand->get().getType());
15061506
auto distributedDims =
15071507
getDistributedDims(extractResultType, distributedType);
1508-
// Only single dimension distribution is supported.
1509-
if (distributedDims.size() != 1)
1510-
return rewriter.notifyMatchFailure(
1511-
warpOp, "Expecting source to be distributed in a single dimension.");
1512-
int64_t distributedDim = distributedDims[0];
1513-
int sourceDistrDimSize =
1514-
extractOp.getSourceVectorType().getShape()[distributedDim];
1515-
1516-
auto sourceLayout =
1517-
xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
1518-
if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1519-
return rewriter.notifyMatchFailure(
1520-
warpOp, "the source of extract_strided_slice op lacks distribution "
1521-
"layout");
1522-
auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1523-
// Because only single dimension distribution is supported, lane layout size
1524-
// at the distributed dim must be the subgroup size.
1525-
int subgroupSize = sourceLaneLayout[distributedDim];
1526-
// Check if the source size in the distributed dimension is a multiple of
1527-
// subgroup size.
1528-
if (sourceDistrDimSize % subgroupSize != 0)
1529-
return rewriter.notifyMatchFailure(
1530-
warpOp,
1531-
"Source size along distributed dimension is not a multiple of "
1532-
"subgroup size.");
1533-
auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1534-
// We expect lane data to be all ones in this case.
1535-
if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1536-
return rewriter.notifyMatchFailure(
1537-
warpOp, "Expecting unit lane data in source layout");
1538-
// The offsets in the distributed dimension must be a multiple of subgroup
1539-
// size.
1540-
int64_t distrDimOffset =
1541-
cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
1542-
if (distrDimOffset % subgroupSize != 0)
1543-
return rewriter.notifyMatchFailure(warpOp,
1544-
"Offset along distributed dimension "
1545-
"is not a multiple of subgroup size.");
1546-
// Do the distribution by yielding the source of the extract op from
1547-
// the warp op and creating a new extract op outside the warp op.
1548-
VectorType sourceDistType =
1549-
getDistVecTypeBasedOnLaneLayout(sourceLayout,
1550-
extractOp.getSourceVectorType())
1551-
.value();
1508+
// Source distributed type must be adjusted for the distributed case.
1509+
VectorType sourceDistType = extractOp.getSourceVectorType();
1510+
// Distributed sizes and offsets must be adjusted for distributed case.
1511+
SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1512+
extractOp.getSizes(), [](Attribute attr) { return attr; });
1513+
SmallVector<Attribute> distributedOffsets = llvm::map_to_vector(
1514+
extractOp.getOffsets(), [](Attribute attr) { return attr; });
1515+
// If the result is distributed, it must be distributed in exactly one
1516+
// dimension. In this case, we adjust the sourceDistType, distributedSizes
1517+
// and distributedOffsets accordingly.
1518+
if (distributedDims.size() > 0) {
1519+
if (distributedDims.size() != 1)
1520+
return rewriter.notifyMatchFailure(
1521+
warpOp, "Source cannot be distributed in multiple dimensions.");
1522+
int64_t distributedDim = distributedDims[0];
1523+
int sourceDistrDimSize =
1524+
extractOp.getSourceVectorType().getShape()[distributedDim];
1525+
auto sourceLayout =
1526+
xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
1527+
if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1528+
return rewriter.notifyMatchFailure(
1529+
warpOp, "the source of extract_strided_slice op lacks distribution "
1530+
"layout");
1531+
auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1532+
// Because only single dimension distribution is supported, lane layout
1533+
// size at the distributed dim must be the subgroup size.
1534+
int subgroupSize = sourceLaneLayout[distributedDim];
1535+
// Check if the source size in the distributed dimension is a multiple of
1536+
// subgroup size.
1537+
if (sourceDistrDimSize % subgroupSize != 0)
1538+
return rewriter.notifyMatchFailure(
1539+
warpOp,
1540+
"Source size along distributed dimension is not a multiple of "
1541+
"subgroup size.");
1542+
auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1543+
// We expect lane data to be all ones in this case.
1544+
if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1545+
return rewriter.notifyMatchFailure(
1546+
warpOp, "Expecting unit lane data in source layout");
1547+
// The offsets in the distributed dimension must be a multiple of subgroup
1548+
// size.
1549+
int64_t distrDimOffset =
1550+
cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
1551+
if (distrDimOffset % subgroupSize != 0)
1552+
return rewriter.notifyMatchFailure(
1553+
warpOp, "Offset along distributed dimension "
1554+
"is not a multiple of subgroup size.");
1555+
// Do the distribution by yielding the source of the extract op from
1556+
// the warp op and creating a new extract op outside the warp op.
1557+
sourceDistType = getDistVecTypeBasedOnLaneLayout(
1558+
sourceLayout, extractOp.getSourceVectorType())
1559+
.value();
1560+
// Update the distributed sizes to match the distributed type.
1561+
distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1562+
distributedType.getDimSize(distributedDim));
1563+
// Update the distributed offsets to match round robin distribution (i.e.
1564+
// each lane owns data at `subgroupSize` stride given unit lane data).
1565+
distributedOffsets[distributedDim] =
1566+
rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1567+
}
15521568
// Create a new warp op that yields the source of the extract op.
15531569
SmallVector<size_t> newRetIndices;
15541570
auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
15551571
rewriter, warpOp, {extractOp.getSource()}, {sourceDistType},
15561572
newRetIndices);
15571573
rewriter.setInsertionPointAfter(newWarpOp);
1558-
// Distributed sizes and offsets must be adjusted.
1559-
SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1560-
extractOp.getSizes(), [](Attribute attr) { return attr; });
1561-
SmallVector<Attribute> distributedOffsets = llvm::map_to_vector(
1562-
extractOp.getOffsets(), [](Attribute attr) { return attr; });
1563-
// Update the distributed sizes to match the distributed type.
1564-
distributedSizes[distributedDim] =
1565-
rewriter.getI64IntegerAttr(distributedType.getDimSize(distributedDim));
1566-
// Update the distributed offsets to match round robin distribution (i.e.
1567-
// each lane owns data at `subgroupSize` stride given unit lane data).
1568-
distributedOffsets[distributedDim] =
1569-
rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
15701574
Value source = newWarpOp.getResult(newRetIndices[0]);
15711575
// Create a new extract op outside the warp op.
15721576
Value newExtractOp = vector::ExtractStridedSliceOp::create(
@@ -1602,87 +1606,97 @@ struct VectorInsertStridedSliceDistribution
16021606
auto insertResultType = cast<VectorType>(operand->get().getType());
16031607
auto destDistributedDims =
16041608
getDistributedDims(insertResultType, distributedType);
1605-
// Only single dimension distribution is supported.
1606-
if (destDistributedDims.size() != 1)
1607-
return rewriter.notifyMatchFailure(
1608-
warpOp, "Expecting source to be distributed in a single dimension.");
1609-
int64_t destDistributedDim = destDistributedDims[0];
1610-
1611-
VectorType srcType = insertOp.getSourceVectorType();
1612-
VectorType destType = insertOp.getDestVectorType();
1613-
// Currently we require that both source (kD) and dest (nD) vectors are
1614-
// distributed. This requires that distributedDim (d) is contained in the
1615-
// last k dims of the dest vector (d >= n - k).
1616-
int64_t sourceDistributedDim =
1617-
destDistributedDim - (destType.getRank() - srcType.getRank());
1618-
if (sourceDistributedDim < 0)
1619-
return rewriter.notifyMatchFailure(
1620-
insertOp, "distributed dimension must be in the last k (i.e. source "
1621-
"rank) dims of dest vector");
1622-
int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1623-
// Obtain the source and dest layouts.
1624-
auto destLayout = xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
1625-
auto sourceLayout =
1626-
xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
1627-
if (!destLayout || !sourceLayout ||
1628-
destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1629-
sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1630-
return rewriter.notifyMatchFailure(
1631-
warpOp, "the source or dest of insert_strided_slice op lacks "
1632-
"distribution layout");
1633-
// Because only single dimension distribution is supported, lane layout
1634-
// size at the distributed dim must be the subgroup size.
1635-
int subgroupSize =
1636-
destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1637-
// We require that source and dest lane data are all ones to ensure uniform
1638-
// round robin distribution.
1639-
auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1640-
auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1641-
if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
1642-
!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1643-
return rewriter.notifyMatchFailure(
1644-
warpOp, "Expecting unit lane data in source and dest layouts");
1645-
// Source distributed dim size must be multiples of subgroup size.
1646-
if (srcDistrDimSize % subgroupSize != 0)
1647-
return rewriter.notifyMatchFailure(
1648-
warpOp, "Distributed dimension size in source is not a multiple of "
1649-
"subgroup size.");
1650-
// Offsets in the distributed dimension must be multiples of subgroup size.
1651-
int64_t destDistrDimOffset =
1652-
cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1653-
if (destDistrDimOffset % subgroupSize != 0)
1654-
return rewriter.notifyMatchFailure(
1655-
warpOp,
1656-
"Offset along distributed dimension in dest is not a multiple of "
1657-
"subgroup size.");
1658-
// Do the distribution by yielding the source and dest of the insert op from
1659-
// the warp op and creating a new insert op outside the warp op.
1660-
VectorType sourceDistType =
1661-
getDistVecTypeBasedOnLaneLayout(sourceLayout,
1662-
insertOp.getSourceVectorType())
1663-
.value();
1664-
VectorType destDistType = getDistVecTypeBasedOnLaneLayout(
1665-
destLayout, insertOp.getDestVectorType())
1666-
.value();
1667-
// Create a new warp op that yields the source and dest of the insert op.
1609+
// Collect updated offsets, source type and dest type. They may be updated
1610+
// later if the data is distributed to lanes (as opposed to being owned by
1611+
// all lanes uniformly).
1612+
SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1613+
insertOp.getOffsets(), [](Attribute attr) { return attr; });
1614+
VectorType updatedSourceType = insertOp.getSourceVectorType();
1615+
VectorType updatedDestType = insertOp.getDestVectorType();
1616+
if (destDistributedDims.size() > 0) {
1617+
// Only single dimension distribution is supported.
1618+
if (destDistributedDims.size() != 1)
1619+
return rewriter.notifyMatchFailure(
1620+
warpOp,
1621+
"Expecting source to be distributed in a single dimension.");
1622+
int64_t destDistributedDim = destDistributedDims[0];
1623+
1624+
VectorType srcType = insertOp.getSourceVectorType();
1625+
VectorType destType = insertOp.getDestVectorType();
1626+
// Currently we require that both source (kD) and dest (nD) vectors are
1627+
// distributed. This requires that distributedDim (d) is contained in the
1628+
// last k dims of the dest vector (d >= n - k).
1629+
int64_t sourceDistributedDim =
1630+
destDistributedDim - (destType.getRank() - srcType.getRank());
1631+
if (sourceDistributedDim < 0)
1632+
return rewriter.notifyMatchFailure(
1633+
insertOp,
1634+
"distributed dimension must be in the last k (i.e. source "
1635+
"rank) dims of dest vector");
1636+
int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1637+
// Obtain the source and dest layouts.
1638+
auto destLayout =
1639+
xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
1640+
auto sourceLayout =
1641+
xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
1642+
if (!destLayout || !sourceLayout ||
1643+
destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1644+
sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1645+
return rewriter.notifyMatchFailure(
1646+
warpOp, "the source or dest of insert_strided_slice op lacks "
1647+
"distribution layout");
1648+
// Because only single dimension distribution is supported, lane layout
1649+
// size at the distributed dim must be the subgroup size.
1650+
int subgroupSize =
1651+
destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1652+
// We require that source and dest lane data are all ones to ensure
1653+
// uniform round robin distribution.
1654+
auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1655+
auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1656+
if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
1657+
!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
1658+
return rewriter.notifyMatchFailure(
1659+
warpOp, "Expecting unit lane data in source and dest layouts");
1660+
// Source distributed dim size must be multiples of subgroup size.
1661+
if (srcDistrDimSize % subgroupSize != 0)
1662+
return rewriter.notifyMatchFailure(
1663+
warpOp, "Distributed dimension size in source is not a multiple of "
1664+
"subgroup size.");
1665+
// Offsets in the distributed dimension must be multiples of subgroup
1666+
// size.
1667+
int64_t destDistrDimOffset =
1668+
cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1669+
if (destDistrDimOffset % subgroupSize != 0)
1670+
return rewriter.notifyMatchFailure(
1671+
warpOp,
1672+
"Offset along distributed dimension in dest is not a multiple of "
1673+
"subgroup size.");
1674+
// Update the source and dest types based on their layouts.
1675+
updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1676+
sourceLayout, insertOp.getSourceVectorType())
1677+
.value();
1678+
updatedDestType = getDistVecTypeBasedOnLaneLayout(
1679+
destLayout, insertOp.getDestVectorType())
1680+
.value();
1681+
// Update the distributed offsets to match round robin distribution (i.e.
1682+
// each lane owns data at `subgroupSize` stride given unit lane data).
1683+
updatedOffsets[destDistributedDim] =
1684+
rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1685+
}
1686+
// Do the distribution by yielding the source and dest of the insert op
1687+
// from the warp op and creating a new insert op outside the warp op.
16681688
SmallVector<size_t> newRetIndices;
16691689
auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
16701690
rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1671-
{sourceDistType, destDistType}, newRetIndices);
1691+
{updatedSourceType, updatedDestType}, newRetIndices);
16721692
rewriter.setInsertionPointAfter(newWarpOp);
1673-
// Distributed offsets must be adjusted.
1674-
SmallVector<Attribute> distributedOffsets = llvm::map_to_vector(
1675-
insertOp.getOffsets(), [](Attribute attr) { return attr; });
1676-
// Update the distributed offsets to match round robin distribution (i.e.
1677-
// each lane owns data at `subgroupSize` stride given unit lane data).
1678-
distributedOffsets[destDistributedDim] =
1679-
rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1693+
16801694
Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
16811695
Value dest = newWarpOp.getResult(newRetIndices[1]);
16821696
// Create a new insert op outside the warp op.
16831697
Value newInsertOp = vector::InsertStridedSliceOp::create(
1684-
rewriter, insertOp.getLoc(), destDistType, valueToStore, dest,
1685-
ArrayAttr::get(rewriter.getContext(), distributedOffsets),
1698+
rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1699+
ArrayAttr::get(rewriter.getContext(), updatedOffsets),
16861700
insertOp.getStrides());
16871701
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
16881702
newInsertOp);

0 commit comments

Comments
 (0)