@@ -174,6 +174,21 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
   return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
 }
 
+/// Given a vector type and its distributed vector type, return the list of
+/// dimensions that are distributed.
+static SmallVector<int64_t> getDistributedDims(VectorType originalType,
+                                               VectorType distributedType) {
+  assert(originalType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  SmallVector<int64_t> distributedDims;
+  for (int64_t i = 0; i < originalType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
+      distributedDims.push_back(i);
+    }
+  }
+  return distributedDims;
+}
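For illustration, a minimal standalone sketch of the same per-dimension size comparison, using plain C++ shape vectors instead of MLIR's VectorType; the 32x16 to 2x16 shapes are a hypothetical example, not taken from the patch:

#include <cstdint>
#include <cstdio>
#include <vector>

// Same idea as getDistributedDims: a dimension is "distributed" if its
// per-lane size differs from the sequential size.
static std::vector<int64_t>
distributedDims(const std::vector<int64_t> &original,
                const std::vector<int64_t> &distributed) {
  std::vector<int64_t> dims;
  for (size_t i = 0; i < original.size(); ++i)
    if (original[i] != distributed[i])
      dims.push_back(static_cast<int64_t>(i));
  return dims;
}

int main() {
  // A 32x16 vector distributed over 16 lanes along dim 0 becomes 2x16 per
  // lane, so dim 0 is the only distributed dimension.
  std::vector<int64_t> dims = distributedDims({32, 16}, {2, 16});
  std::printf("distributed dim: %lld\n", (long long)dims.front()); // prints 0
  return 0;
}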
+
 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
 /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
 /// contained within a WarpExecuteOnLane0Op.
@@ -1469,6 +1484,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
+// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
+// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+// advanced cases where the distributed dimension is partially extracted and
+// currently not supported by the generic vector distribution patterns.
+struct VectorExtractStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    auto extractOp =
+        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+    unsigned operandIdx = operand->getOperandNumber();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Find the distributed dimensions.
+    auto extractResultType = cast<VectorType>(operand->get().getType());
+    auto distributedDims =
+        getDistributedDims(extractResultType, distributedType);
+    // Collect updated source type, sizes and offsets. They may be adjusted
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    VectorType updatedSourceType = extractOp.getSourceVectorType();
+    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    // If the result is distributed, it must be distributed in exactly one
+    // dimension. In this case, we adjust the updatedSourceType, updatedSizes
+    // and updatedOffsets accordingly.
+    if (distributedDims.size() > 0) {
+      if (distributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source can not be distributed in multiple dimensions.");
+      int64_t distributedDim = distributedDims[0];
+      int sourceDistrDimSize =
+          extractOp.getSourceVectorType().getShape()[distributedDim];
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source of extract_strided_slice op lacks distribution "
+                    "layout");
+      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize = sourceLaneLayout[distributedDim];
+      // Check if the source size in the distributed dimension is a multiple of
+      // subgroup size.
+      if (sourceDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Source size along distributed dimension is not a multiple of "
+            "subgroup size.");
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      // We expect lane data to be all ones in this case.
+      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source layout");
+      // The offsets in the distributed dimension must be a multiple of
+      // subgroup size.
+      int64_t distrDimOffset =
+          cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+      if (distrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Offset along distributed dimension "
+                    "is not a multiple of subgroup size.");
+      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+                              sourceLayout, extractOp.getSourceVectorType())
+                              .value();
+      // Update the distributed sizes to match the distributed type.
+      updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+      // Update the distributed offsets to match round robin distribution (i.e.
+      // each lane owns data at `subgroupSize` stride given unit lane data).
+      updatedOffsets[distributedDim] =
+          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+    }
+    // Do the distribution by yielding the source of the extract op from
+    // the warp op and creating a new extract op outside the warp op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value source = newWarpOp.getResult(newRetIndices[0]);
+    // Create a new extract op outside the warp op.
+    Value newExtractOp = vector::ExtractStridedSliceOp::create(
+        rewriter, extractOp.getLoc(), distributedType, source,
+        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+        ArrayAttr::get(rewriter.getContext(), updatedSizes),
+        extractOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
+    return success();
+  }
+};
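For illustration, a minimal standalone sketch of the round-robin offset/size arithmetic this pattern performs along the distributed dimension, assuming unit lane data and a single distributed dimension; all concrete numbers are hypothetical (the pattern itself reads the per-lane size from the distributed result type rather than dividing):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t subgroupSize = 16; // lane_layout size at the distributed dim
  const int64_t sourceSize = 64;   // source size along the distributed dim
  const int64_t offset = 32;       // extract offset along the distributed dim
  const int64_t size = 16;         // extract size along the distributed dim

  // Legality checks mirrored from the pattern.
  assert(sourceSize % subgroupSize == 0);
  assert(offset % subgroupSize == 0);

  // Per-lane view after round-robin distribution: each lane owns every
  // subgroupSize-th element, so sizes and offsets shrink by that factor.
  std::printf("per-lane source size  = %lld\n",
              (long long)(sourceSize / subgroupSize)); // 4
  std::printf("per-lane extract size = %lld\n",
              (long long)(size / subgroupSize)); // 1
  std::printf("per-lane offset       = %lld\n",
              (long long)(offset / subgroupSize)); // 2
  return 0;
}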
+
+/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
+/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases where the distributed dimension is partially inserted and
+/// currently not supported by the generic vector distribution patterns.
+struct VectorInsertStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Find the distributed dimensions of the dest vector.
+    auto insertResultType = cast<VectorType>(operand->get().getType());
+    auto destDistributedDims =
+        getDistributedDims(insertResultType, distributedType);
+    // Collect updated offsets, source type and dest type. They may be adjusted
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        insertOp.getOffsets(), [](Attribute attr) { return attr; });
+    VectorType updatedSourceType = insertOp.getSourceVectorType();
+    VectorType updatedDestType = insertOp.getDestVectorType();
+    if (destDistributedDims.size() > 0) {
+      // Only single dimension distribution is supported.
+      if (destDistributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Expecting source to be distributed in a single dimension.");
+      int64_t destDistributedDim = destDistributedDims[0];
+
+      VectorType srcType = insertOp.getSourceVectorType();
+      VectorType destType = insertOp.getDestVectorType();
+      // Currently we require that both source (kD) and dest (nD) vectors are
+      // distributed. This requires that distributedDim (d) is contained in the
+      // last k dims of the dest vector (d >= n - k).
+      int64_t sourceDistributedDim =
+          destDistributedDim - (destType.getRank() - srcType.getRank());
+      if (sourceDistributedDim < 0)
+        return rewriter.notifyMatchFailure(
+            insertOp,
+            "distributed dimension must be in the last k (i.e. source "
+            "rank) dims of dest vector");
+      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
+      // Obtain the source and dest layouts.
+      auto destLayout =
+          xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
+      if (!destLayout || !sourceLayout ||
+          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source or dest of insert_strided_slice op lacks "
+                    "distribution layout");
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize =
+          destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
+      // We require that source and dest lane data are all ones to ensure
+      // uniform round robin distribution.
+      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
+          !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source and dest layouts");
+      // Source distributed dim size must be a multiple of subgroup size.
+      if (srcDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Distributed dimension size in source is not a multiple of "
+                    "subgroup size.");
+      // Offsets in the distributed dimension must be multiples of subgroup
+      // size.
+      int64_t destDistrDimOffset =
+          cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
+      if (destDistrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Offset along distributed dimension in dest is not a multiple of "
+            "subgroup size.");
+      // Update the source and dest types based on their layouts.
+      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+                              sourceLayout, insertOp.getSourceVectorType())
+                              .value();
+      updatedDestType = getDistVecTypeBasedOnLaneLayout(
+                            destLayout, insertOp.getDestVectorType())
+                            .value();
+      // Update the distributed offsets to match round robin distribution (i.e.
+      // each lane owns data at `subgroupSize` stride given unit lane data).
+      updatedOffsets[destDistributedDim] =
+          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+    }
+    // Do the distribution by yielding the source and dest of the insert op
+    // from the warp op and creating a new insert op outside the warp op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {updatedSourceType, updatedDestType}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
+    Value dest = newWarpOp.getResult(newRetIndices[1]);
+    // Create a new insert op outside the warp op.
+    Value newInsertOp = vector::InsertStridedSliceOp::create(
+        rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
+        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+        insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+                                newInsertOp);
+    return success();
+  }
+};
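For illustration, a minimal standalone sketch of the dimension mapping and offset adjustment used above, assuming unit lane data; the ranks, distributed dimension, subgroup size, and offset are hypothetical:

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t destRank = 3, srcRank = 2; // nD dest, kD source
  const int64_t destDistributedDim = 2;    // distributed dim of the dest
  const int64_t subgroupSize = 16;
  const int64_t destOffset = 48;           // insert offset along that dim

  // The distributed dim must lie within the trailing srcRank dims of the
  // dest so that the source is distributed as well (d >= n - k).
  const int64_t srcDistributedDim = destDistributedDim - (destRank - srcRank);
  assert(srcDistributedDim >= 0); // here: 1

  // With round-robin ownership, the per-lane insert offset shrinks by the
  // subgroup size.
  assert(destOffset % subgroupSize == 0);
  std::printf("source distributed dim = %lld\n", (long long)srcDistributedDim);
  std::printf("per-lane dest offset   = %lld\n",
              (long long)(destOffset / subgroupSize)); // 3
  return 0;
}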
+
 /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
 /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
 /// outside of the warp op.
@@ -1626,9 +1861,13 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
            MemrefExtractAlignedPointerAsIndexDistribution>(
           patterns.getContext(),
           /*pattern benefit=*/regularPatternBenefit);
-  patterns.add<VectorShapeCastDistribution>(
-      patterns.getContext(),
-      /*pattern benefit=*/highPatternBenefit);
+  // For the following patterns, we need to override the regular vector
+  // distribution patterns. Therefore, assign a higher benefit.
+  patterns
+      .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
+           VectorInsertStridedSliceDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/highPatternBenefit);
 }
 
 void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(