@@ -179,26 +179,20 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
179179 SmallVector<OpFoldResult> offset = op.getMixedOffsets ();
180180
181181 for (auto tdescOffset : *maybeTdescOffsets) {
182- SmallVector<OpFoldResult> newOffsets = llvm::map_to_vector (
183- llvm::zip_longest (tdescOffset, offset),
184- [&](const auto &t) -> OpFoldResult {
185- std::optional<Value> off = std::get<0 >(t);
186- std::optional<OpFoldResult> old = std::get<1 >(t);
187- if (!off.has_value ())
188- return *old;
189-
190- if (!old.has_value () || isZeroInteger (*old))
191- return *off;
192-
193- return rewriter.createOrFold <index::AddOp>(
194- loc, *off,
195- getValueOrCreateConstantIndexOp (rewriter, loc, *old));
196- });
197-
198- auto newCreateNdOp = xegpu::CreateNdDescOp::create (
182+ SmallVector<OpFoldResult> newOffsets;
183+ size_t rank = tdescOffset.size ();
184+ for (size_t i = 0 ; i < rank; i++) {
185+ size_t idx = offset.size () - rank + i;
186+ Value newOff = rewriter.createOrFold <index::AddOp>(
187+ loc, tdescOffset[i],
188+ getValueOrCreateConstantIndexOp (rewriter, loc, offset[idx]));
189+ newOffsets.push_back (newOff);
190+ }
191+
192+ auto newOp = xegpu::CreateNdDescOp::create (
199193 rewriter, loc, newTdescTy, op.getSource (), newOffsets,
200194 op.getMixedSizes (), op.getMixedStrides ());
201- newCreateNdOps.push_back (newCreateNdOp );
195+ newCreateNdOps.push_back (newOp );
202196 }
203197 rewriter.replaceOpWithMultiple (op, {newCreateNdOps});
204198 return success ();
0 commit comments