@@ -1505,68 +1505,72 @@ struct VectorExtractStridedSliceDistribution
15051505 auto extractResultType = cast<VectorType>(operand->get ().getType ());
15061506 auto distributedDims =
15071507 getDistributedDims (extractResultType, distributedType);
1508- // Only single dimension distribution is supported.
1509- if (distributedDims.size () != 1 )
1510- return rewriter.notifyMatchFailure (
1511- warpOp, " Expecting source to be distributed in a single dimension." );
1512- int64_t distributedDim = distributedDims[0 ];
1513- int sourceDistrDimSize =
1514- extractOp.getSourceVectorType ().getShape ()[distributedDim];
1515-
1516- auto sourceLayout =
1517- xegpu::getDistributeLayoutAttr (extractOp->getOpOperand (0 ));
1518- if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt ().empty ())
1519- return rewriter.notifyMatchFailure (
1520- warpOp, " the source of extract_strided_slice op lacks distribution "
1521- " layout" );
1522- auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt ();
1523- // Because only single dimension distribution is supported, lane layout size
1524- // at the distributed dim must be the subgroup size.
1525- int subgroupSize = sourceLaneLayout[distributedDim];
1526- // Check if the source size in the distributed dimension is a multiple of
1527- // subgroup size.
1528- if (sourceDistrDimSize % subgroupSize != 0 )
1529- return rewriter.notifyMatchFailure (
1530- warpOp,
1531- " Source size along distributed dimension is not a multiple of "
1532- " subgroup size." );
1533- auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt ();
1534- // We expect lane data to be all ones in this case.
1535- if (!llvm::all_of (sourceLaneData, [](int64_t v) { return v == 1 ; }))
1536- return rewriter.notifyMatchFailure (
1537- warpOp, " Expecting unit lane data in source layout" );
1538- // The offsets in the distributed dimention must be a multiple of subgroup
1539- // size.
1540- int64_t distrDimOffset =
1541- cast<IntegerAttr>(extractOp.getOffsets ()[distributedDim]).getInt ();
1542- if (distrDimOffset % subgroupSize != 0 )
1543- return rewriter.notifyMatchFailure (warpOp,
1544- " Offset along distributed dimension "
1545- " is not a multiple of subgroup size." );
1546- // Do the distribution by yielding the source of the extract op from
1547- // the warp op and creating a new extract op outside the warp op.
1548- VectorType sourceDistType =
1549- getDistVecTypeBasedOnLaneLayout (sourceLayout,
1550- extractOp.getSourceVectorType ())
1551- .value ();
1508+ // Source distributed type must be adjusted for the distributed case.
1509+ VectorType sourceDistType = extractOp.getSourceVectorType ();
1510+ // Distributed sizes and offsets must be adjusted for distributed case.
1511+ SmallVector<Attribute> distributedSizes = llvm::map_to_vector (
1512+ extractOp.getSizes (), [](Attribute attr) { return attr; });
1513+ SmallVector<Attribute> distributedOffsets = llvm::map_to_vector (
1514+ extractOp.getOffsets (), [](Attribute attr) { return attr; });
1515+ // If the result is distributed, it must be distributed in exactly one
1516+ // dimension. In this case, we adjust the sourceDistType, distributedSizes
1517+ // and distributedOffsets accordingly.
1518+ if (distributedDims.size () > 0 ) {
1519+ if (distributedDims.size () != 1 )
1520+ return rewriter.notifyMatchFailure (
1521+ warpOp, " Source can not be distributed in multiple dimensions." );
1522+ int64_t distributedDim = distributedDims[0 ];
1523+ int sourceDistrDimSize =
1524+ extractOp.getSourceVectorType ().getShape ()[distributedDim];
1525+ auto sourceLayout =
1526+ xegpu::getDistributeLayoutAttr (extractOp->getOpOperand (0 ));
1527+ if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt ().empty ())
1528+ return rewriter.notifyMatchFailure (
1529+ warpOp, " the source of extract_strided_slice op lacks distribution "
1530+ " layout" );
1531+ auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt ();
1532+ // Because only single dimension distribution is supported, lane layout
1533+ // size at the distributed dim must be the subgroup size.
1534+ int subgroupSize = sourceLaneLayout[distributedDim];
1535+ // Check if the source size in the distributed dimension is a multiple of
1536+ // subgroup size.
1537+ if (sourceDistrDimSize % subgroupSize != 0 )
1538+ return rewriter.notifyMatchFailure (
1539+ warpOp,
1540+ " Source size along distributed dimension is not a multiple of "
1541+ " subgroup size." );
1542+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt ();
1543+ // We expect lane data to be all ones in this case.
1544+ if (!llvm::all_of (sourceLaneData, [](int64_t v) { return v == 1 ; }))
1545+ return rewriter.notifyMatchFailure (
1546+ warpOp, " Expecting unit lane data in source layout" );
1547+ // The offsets in the distributed dimension must be a multiple of subgroup
1548+ // size.
1549+ int64_t distrDimOffset =
1550+ cast<IntegerAttr>(extractOp.getOffsets ()[distributedDim]).getInt ();
1551+ if (distrDimOffset % subgroupSize != 0 )
1552+ return rewriter.notifyMatchFailure (
1553+ warpOp, " Offset along distributed dimension "
1554+ " is not a multiple of subgroup size." );
1555+ // Do the distribution by yielding the source of the extract op from
1556+ // the warp op and creating a new extract op outside the warp op.
1557+ sourceDistType = getDistVecTypeBasedOnLaneLayout (
1558+ sourceLayout, extractOp.getSourceVectorType ())
1559+ .value ();
1560+ // Update the distributed sizes to match the distributed type.
1561+ distributedSizes[distributedDim] = rewriter.getI64IntegerAttr (
1562+ distributedType.getDimSize (distributedDim));
1563+ // Update the distributed offsets to match round robin distribution (i.e.
1564+ // each lane owns data at `subgroupSize` stride given unit lane data).
1565+ distributedOffsets[distributedDim] =
1566+ rewriter.getI64IntegerAttr (distrDimOffset / subgroupSize);
1567+ }
15521568 // Create a new warp op that yields the source of the extract op.
15531569 SmallVector<size_t > newRetIndices;
15541570 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
15551571 rewriter, warpOp, {extractOp.getSource ()}, {sourceDistType},
15561572 newRetIndices);
15571573 rewriter.setInsertionPointAfter (newWarpOp);
1558- // Distributed sizes and offsets must be adjusted.
1559- SmallVector<Attribute> distributedSizes = llvm::map_to_vector (
1560- extractOp.getSizes (), [](Attribute attr) { return attr; });
1561- SmallVector<Attribute> distributedOffsets = llvm::map_to_vector (
1562- extractOp.getOffsets (), [](Attribute attr) { return attr; });
1563- // Update the distributed sizes to match the distributed type.
1564- distributedSizes[distributedDim] =
1565- rewriter.getI64IntegerAttr (distributedType.getDimSize (distributedDim));
1566- // Update the distributed offsets to match round robin distribution (i.e.
1567- // each lane owns data at `subgroupSize` stride given unit lane data).
1568- distributedOffsets[distributedDim] =
1569- rewriter.getI64IntegerAttr (distrDimOffset / subgroupSize);
15701574 Value source = newWarpOp.getResult (newRetIndices[0 ]);
15711575 // Create a new extract op outside the warp op.
15721576 Value newExtractOp = vector::ExtractStridedSliceOp::create (
@@ -1602,87 +1606,97 @@ struct VectorInsertStridedSliceDistribution
16021606 auto insertResultType = cast<VectorType>(operand->get ().getType ());
16031607 auto destDistributedDims =
16041608 getDistributedDims (insertResultType, distributedType);
1605- // Only single dimension distribution is supported.
1606- if (destDistributedDims.size () != 1 )
1607- return rewriter.notifyMatchFailure (
1608- warpOp, " Expecting source to be distributed in a single dimension." );
1609- int64_t destDistributedDim = destDistributedDims[0 ];
1610-
1611- VectorType srcType = insertOp.getSourceVectorType ();
1612- VectorType destType = insertOp.getDestVectorType ();
1613- // Currently we require that both source (kD) and dest (nD) vectors are
1614- // distributed. This requires that distributedDim (d) is contained in the
1615- // last k dims of the dest vector (d >= n - k).
1616- int64_t sourceDistributedDim =
1617- destDistributedDim - (destType.getRank () - srcType.getRank ());
1618- if (sourceDistributedDim < 0 )
1619- return rewriter.notifyMatchFailure (
1620- insertOp, " distributed dimension must be in the last k (i.e. source "
1621- " rank) dims of dest vector" );
1622- int64_t srcDistrDimSize = srcType.getDimSize (sourceDistributedDim);
1623- // Obtain the source and dest layouts.
1624- auto destLayout = xegpu::getDistributeLayoutAttr (insertOp->getOpOperand (1 ));
1625- auto sourceLayout =
1626- xegpu::getDistributeLayoutAttr (insertOp->getOpOperand (0 ));
1627- if (!destLayout || !sourceLayout ||
1628- destLayout.getEffectiveLaneLayoutAsInt ().empty () ||
1629- sourceLayout.getEffectiveLaneLayoutAsInt ().empty ())
1630- return rewriter.notifyMatchFailure (
1631- warpOp, " the source or dest of insert_strided_slice op lacks "
1632- " distribution layout" );
1633- // Because only single dimension distribution is supported, lane layout
1634- // size at the distributed dim must be the subgroup size.
1635- int subgroupSize =
1636- destLayout.getEffectiveLaneLayoutAsInt ()[destDistributedDim];
1637- // We require that source and dest lane data are all ones to ensure uniform
1638- // round robin distribution.
1639- auto destLaneData = destLayout.getEffectiveLaneDataAsInt ();
1640- auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt ();
1641- if (!llvm::all_of (destLaneData, [](int64_t v) { return v == 1 ; }) ||
1642- !llvm::all_of (sourceLaneData, [](int64_t v) { return v == 1 ; }))
1643- return rewriter.notifyMatchFailure (
1644- warpOp, " Expecting unit lane data in source and dest layouts" );
1645- // Source distributed dim size must be multiples of subgroup size.
1646- if (srcDistrDimSize % subgroupSize != 0 )
1647- return rewriter.notifyMatchFailure (
1648- warpOp, " Distributed dimension size in source is not a multiple of "
1649- " subgroup size." );
1650- // Offsets in the distributed dimension must be multiples of subgroup size.
1651- int64_t destDistrDimOffset =
1652- cast<IntegerAttr>(insertOp.getOffsets ()[destDistributedDim]).getInt ();
1653- if (destDistrDimOffset % subgroupSize != 0 )
1654- return rewriter.notifyMatchFailure (
1655- warpOp,
1656- " Offset along distributed dimension in dest is not a multiple of "
1657- " subgroup size." );
1658- // Do the distribution by yielding the source and dest of the insert op from
1659- // the warp op and creating a new insert op outside the warp op.
1660- VectorType sourceDistType =
1661- getDistVecTypeBasedOnLaneLayout (sourceLayout,
1662- insertOp.getSourceVectorType ())
1663- .value ();
1664- VectorType destDistType = getDistVecTypeBasedOnLaneLayout (
1665- destLayout, insertOp.getDestVectorType ())
1666- .value ();
1667- // Create a new warp op that yields the source and dest of the insert op.
1609+ // Collect updated offsets, source type and dest type. They may be updated
1610+ // later if the data is distributed to lanes (as opposed to being owned by
1611+ // all lanes uniformly).
1612+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector (
1613+ insertOp.getOffsets (), [](Attribute attr) { return attr; });
1614+ VectorType updatedSourceType = insertOp.getSourceVectorType ();
1615+ VectorType updatedDestType = insertOp.getDestVectorType ();
1616+ if (destDistributedDims.size () > 0 ) {
1617+ // Only single dimension distribution is supported.
1618+ if (destDistributedDims.size () != 1 )
1619+ return rewriter.notifyMatchFailure (
1620+ warpOp,
1621+ " Expecting source to be distributed in a single dimension." );
1622+ int64_t destDistributedDim = destDistributedDims[0 ];
1623+
1624+ VectorType srcType = insertOp.getSourceVectorType ();
1625+ VectorType destType = insertOp.getDestVectorType ();
1626+ // Currently we require that both source (kD) and dest (nD) vectors are
1627+ // distributed. This requires that distributedDim (d) is contained in the
1628+ // last k dims of the dest vector (d >= n - k).
1629+ int64_t sourceDistributedDim =
1630+ destDistributedDim - (destType.getRank () - srcType.getRank ());
1631+ if (sourceDistributedDim < 0 )
1632+ return rewriter.notifyMatchFailure (
1633+ insertOp,
1634+ " distributed dimension must be in the last k (i.e. source "
1635+ " rank) dims of dest vector" );
1636+ int64_t srcDistrDimSize = srcType.getDimSize (sourceDistributedDim);
1637+ // Obtain the source and dest layouts.
1638+ auto destLayout =
1639+ xegpu::getDistributeLayoutAttr (insertOp->getOpOperand (1 ));
1640+ auto sourceLayout =
1641+ xegpu::getDistributeLayoutAttr (insertOp->getOpOperand (0 ));
1642+ if (!destLayout || !sourceLayout ||
1643+ destLayout.getEffectiveLaneLayoutAsInt ().empty () ||
1644+ sourceLayout.getEffectiveLaneLayoutAsInt ().empty ())
1645+ return rewriter.notifyMatchFailure (
1646+ warpOp, " the source or dest of insert_strided_slice op lacks "
1647+ " distribution layout" );
1648+ // Because only single dimension distribution is supported, lane layout
1649+ // size at the distributed dim must be the subgroup size.
1650+ int subgroupSize =
1651+ destLayout.getEffectiveLaneLayoutAsInt ()[destDistributedDim];
1652+ // We require that source and dest lane data are all ones to ensure
1653+ // uniform round robin distribution.
1654+ auto destLaneData = destLayout.getEffectiveLaneDataAsInt ();
1655+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt ();
1656+ if (!llvm::all_of (destLaneData, [](int64_t v) { return v == 1 ; }) ||
1657+ !llvm::all_of (sourceLaneData, [](int64_t v) { return v == 1 ; }))
1658+ return rewriter.notifyMatchFailure (
1659+ warpOp, " Expecting unit lane data in source and dest layouts" );
1660+ // Source distributed dim size must be multiples of subgroup size.
1661+ if (srcDistrDimSize % subgroupSize != 0 )
1662+ return rewriter.notifyMatchFailure (
1663+ warpOp, " Distributed dimension size in source is not a multiple of "
1664+ " subgroup size." );
1665+ // Offsets in the distributed dimension must be multiples of subgroup
1666+ // size.
1667+ int64_t destDistrDimOffset =
1668+ cast<IntegerAttr>(insertOp.getOffsets ()[destDistributedDim]).getInt ();
1669+ if (destDistrDimOffset % subgroupSize != 0 )
1670+ return rewriter.notifyMatchFailure (
1671+ warpOp,
1672+ " Offset along distributed dimension in dest is not a multiple of "
1673+ " subgroup size." );
1674+ // Update the source and dest types based on their layouts.
1675+ updatedSourceType = getDistVecTypeBasedOnLaneLayout (
1676+ sourceLayout, insertOp.getSourceVectorType ())
1677+ .value ();
1678+ updatedDestType = getDistVecTypeBasedOnLaneLayout (
1679+ destLayout, insertOp.getDestVectorType ())
1680+ .value ();
1681+ // Update the distributed offsets to match round robin distribution (i.e.
1682+ // each lane owns data at `subgroupSize` stride given unit lane data).
1683+ updatedOffsets[destDistributedDim] =
1684+ rewriter.getI64IntegerAttr (destDistrDimOffset / subgroupSize);
1685+ }
1686+ // Do the distribution by yielding the source and dest of the insert op
1687+ // from the warp op and creating a new insert op outside the warp op.
16681688 SmallVector<size_t > newRetIndices;
16691689 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
16701690 rewriter, warpOp, {insertOp.getValueToStore (), insertOp.getDest ()},
1671- {sourceDistType, destDistType }, newRetIndices);
1691+ {updatedSourceType, updatedDestType }, newRetIndices);
16721692 rewriter.setInsertionPointAfter (newWarpOp);
1673- // Distributed offsets must be adjusted.
1674- SmallVector<Attribute> distributedOffsets = llvm::map_to_vector (
1675- insertOp.getOffsets (), [](Attribute attr) { return attr; });
1676- // Update the distributed offsets to match round robin distribution (i.e.
1677- // each lane owns data at `subgroupSize` stride given unit lane data).
1678- distributedOffsets[destDistributedDim] =
1679- rewriter.getI64IntegerAttr (destDistrDimOffset / subgroupSize);
1693+
16801694 Value valueToStore = newWarpOp.getResult (newRetIndices[0 ]);
16811695 Value dest = newWarpOp.getResult (newRetIndices[1 ]);
16821696 // Create a new insert op outside the warp op.
16831697 Value newInsertOp = vector::InsertStridedSliceOp::create (
1684- rewriter, insertOp.getLoc (), destDistType , valueToStore, dest,
1685- ArrayAttr::get (rewriter.getContext (), distributedOffsets ),
1698+ rewriter, insertOp.getLoc (), updatedDestType , valueToStore, dest,
1699+ ArrayAttr::get (rewriter.getContext (), updatedOffsets ),
16861700 insertOp.getStrides ());
16871701 rewriter.replaceAllUsesWith (newWarpOp.getResult (operandNumber),
16881702 newInsertOp);
0 commit comments