@@ -64,27 +64,6 @@ namespace {
6464static constexpr unsigned regularPatternBenefit = 1 ;
6565static constexpr unsigned highPatternBenefit = 2 ;
6666
67- // / Helper function to compute the effective lane layout from a
68- // / DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr.
69- static SmallVector<int64_t >
70- computeEffectiveLaneLayout (const xegpu::DistributeLayoutAttr layout) {
71- SmallVector<int64_t > effectiveLaneLayout;
72- // If the layout is a slice, we need to get effective lane layout by removing
73- // sliced dims.
74- if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
75- ArrayRef<int64_t > slicedDims = sliceAttr.flatten ().getDims ().asArrayRef ();
76- llvm::DenseSet<int64_t > lookUp (slicedDims.begin (), slicedDims.end ());
77- for (auto [i, dim] :
78- llvm::enumerate (sliceAttr.getParent ().getLaneLayoutAsInt ())) {
79- if (!lookUp.contains (i))
80- effectiveLaneLayout.push_back (dim);
81- }
82- } else {
83- effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt ();
84- }
85- return effectiveLaneLayout;
86- }
87-
8867// / Helper function to get distributed vector type for a source vector type
8968// / according to the lane_layout. We simply divide each dimension of tensor
9069// / descriptor shape by corresponding lane_layout dimension. If
@@ -105,9 +84,11 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
10584 return failure ();
10685 assert ((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
10786 " Expecting a valid layout." );
108- SmallVector<int64_t > effectiveLaneLayout = computeEffectiveLaneLayout (layout);
87+ SmallVector<int64_t > effectiveLaneLayout =
88+ xegpu::computeEffectiveLaneLayout (layout);
10989
110- assert (originalType.getShape ().size () >= effectiveLaneLayout.size () &&
90+ assert (static_cast <size_t >(originalType.getRank ()) >=
91+ effectiveLaneLayout.size () &&
11192 " Rank of the original vector type should be greater or equal to the "
11293 " size of the lane layout to distribute the vector type." );
11394 SmallVector<int64_t > distributedShape (originalType.getShape ());
@@ -1369,7 +1350,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
13691350 vecRank, {static_cast <unsigned int >(vecRank - 1 )}, val.getContext ());
13701351 SmallVector<unsigned int > distributedDims;
13711352 // Get the distributed dimensions based on the layout.
1372- SmallVector<int64_t > laneLayout = computeEffectiveLaneLayout (layout);
1353+ SmallVector<int64_t > laneLayout = xegpu:: computeEffectiveLaneLayout (layout);
13731354 for (unsigned i = 0 ; i < laneLayout.size (); ++i) {
13741355 if (laneLayout[i] > 1 )
13751356 distributedDims.push_back (i);
0 commit comments