@@ -125,39 +125,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
125125struct WgToSgCreateNdOp : public OpConversionPattern <xegpu::CreateNdDescOp> {
126126 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
127127
128- // Calculate offset for each subgroup: start from a copy of originalOffsets and, for the trailing localOffset.size() dims, replace each entry with originalOffsets[dim] + ((localOffset[i] + distUnitBaseAddr[i]) remu distUnitShape[i]). Leading dims are returned unchanged.
129- static SmallVector<OpFoldResult>
130- calculateGlobalOffsets (ConversionPatternRewriter &rewriter, Location loc,
131- const SmallVector<OpFoldResult> &originalOffsets,
132- const SmallVector<Value> &localOffset,
133- const SmallVector<int64_t > &distUnitBaseAddr,
134- const SmallVector<int64_t > &distUnitShape) {
135- assert (localOffset.size () == distUnitBaseAddr.size () &&
136- " localOffset and distUnitBaseAddr must have the same rank" ); // NOTE(review): distUnitShape's size is not asserted, yet it is indexed with the same i below.
137-
138- SmallVector<OpFoldResult> globalOffsets (originalOffsets.begin (),
139- originalOffsets.end ()); // Result starts as a copy; only trailing dims are overwritten.
140- size_t rank = localOffset.size (); // Number of trailing dims to rewrite.
141- for (size_t i = 0 ; i < rank; ++i) {
142- size_t dimIdx = originalOffsets.size () - rank + i; // Map i onto the trailing dims. NOTE(review): unsigned wrap-around if originalOffsets.size() < rank; not asserted.
143- Value constOffset =
144- arith::ConstantIndexOp::create (rewriter, loc, distUnitBaseAddr[i]); // Materialize distUnitBaseAddr[i] as an index constant.
145- Value offset =
146- rewriter.createOrFold <index::AddOp>(loc, localOffset[i], constOffset); // localOffset[i] + distUnitBaseAddr[i]
147- Value modValue =
148- arith::ConstantIndexOp::create (rewriter, loc, distUnitShape[i]); // Materialize distUnitShape[i] as an index constant.
149- Value offsetMod =
150- rewriter.createOrFold <index::RemUOp>(loc, offset, modValue); // Remainder of the sum by distUnitShape[i]; presumably wraps the offset into the distribution unit (confirm).
151- Value origOffset = getValueOrCreateConstantIndexOp (
152- rewriter, loc, originalOffsets[dimIdx]); // Turn the OpFoldResult into a Value usable as an operand.
153- Value globalOffset =
154- rewriter.createOrFold <index::AddOp>(loc, origOffset, offsetMod); // origOffset + wrapped offset
155- globalOffsets[dimIdx] = globalOffset; // Overwrite just this trailing dim.
156- }
157-
158- return globalOffsets;
159- }
160-
161128 LogicalResult
162129 matchAndRewrite (xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
163130 ConversionPatternRewriter &rewriter) const override {
@@ -177,28 +144,14 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
177144 return rewriter.notifyMatchFailure (
178145 op, " sgLayout attribute is required in layout" );
179146
180- SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
181-
182- // TODO : Handle order attribute
183147 // Get the subgroup ID
184- auto linearSgId =
148+ Value linearSgId =
185149 gpu::SubgroupIdOp::create (rewriter, loc, /* upper_bound=*/ nullptr );
186150
187- // Create constants for layout dimensions
188- SmallVector<Value> sgLayoutDim (sgLayout.size ());
189- SmallVector<Value> sgDataDim (sgShape.size ());
190-
191- for (size_t i = 0 ; i < sgLayout.size (); i++) {
192- sgLayoutDim[i] =
193- arith::ConstantIndexOp::create (rewriter, loc, sgLayout[i]);
194- sgDataDim[i] = arith::ConstantIndexOp::create (rewriter, loc, sgShape[i]);
195- }
196-
197151 int64_t startOfRange = -1 , endOfRange = -1 ;
198152 bool sgIdRangeSpecified =
199153 isSgIdRangeSpecified (op, startOfRange, endOfRange);
200154
201- Value adjustedSgId = linearSgId;
202155 if (sgIdRangeSpecified) {
203156 int64_t sgCount = endOfRange - startOfRange;
204157 if (computeProduct (sgLayout) != sgCount)
@@ -208,22 +161,24 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
208161 // sg id
209162 Value startOfRangeVal =
210163 rewriter.create <arith::ConstantIndexOp>(loc, startOfRange);
211- adjustedSgId =
164+ linearSgId =
212165 rewriter.createOrFold <index::SubOp>(loc, linearSgId, startOfRangeVal);
213166 }
214167
215- auto tdescOffsets = layout.getOffsets (rewriter, loc, adjustedSgId, wgShape);
216- if (failed (tdescOffsets))
168+ auto maybeTdescOffsets =
169+ layout.getOffsets (rewriter, loc, linearSgId, wgShape);
170+ if (failed (maybeTdescOffsets))
217171 return failure ();
218172
173+ SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
219174 xegpu::TensorDescType newTdescTy =
220175 xegpu::TensorDescType::get (ctx, sgShape, elemTy, tdescTy.getEncoding (),
221176 layout.dropSgLayoutAndData ());
222177
223178 SmallVector<Value> newCreateNdOps;
224179 SmallVector<OpFoldResult> offset = op.getMixedOffsets ();
225180
226- for (auto tdescOffset : *tdescOffsets ) {
181+ for (auto tdescOffset : *maybeTdescOffsets ) {
227182 SmallVector<OpFoldResult> newOffsets = llvm::map_to_vector (
228183 llvm::zip_longest (tdescOffset, offset),
229184 [&](const auto &t) -> OpFoldResult {
0 commit comments