@@ -296,13 +296,14 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
296296 }
297297};
298298
299- // Utility function to compute distributed offsets for subgroup operations.
299+ // Utility function to compute global offsets for subgroup operations.
300300// Returns a vector of new offsets for each subgroup, given the original op's
301301// offsets and subgroup relative offsets.
302- static SmallVector<SmallVector<OpFoldResult>> computeDistributedOffsets (
303- Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
304- ArrayRef<OpFoldResult> wgOffsets, ConversionPatternRewriter &rewriter) {
305- SmallVector<SmallVector<OpFoldResult>> distributedOffsets;
302+ static SmallVector<SmallVector<OpFoldResult>>
303+ computeGlobalOffsets (Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
304+ ArrayRef<OpFoldResult> wgOffsets,
305+ ConversionPatternRewriter &rewriter) {
306+ SmallVector<SmallVector<OpFoldResult>> globalOffsets;
306307 Location loc = op->getLoc ();
307308 for (const auto &sgOffsets : sgOffsetsList) {
308309 SmallVector<OpFoldResult> newOffsets;
@@ -314,9 +315,9 @@ static SmallVector<SmallVector<OpFoldResult>> computeDistributedOffsets(
314315 getValueOrCreateConstantIndexOp (rewriter, loc, wgOffsets[idx]));
315316 newOffsets.push_back (add);
316317 }
317- distributedOffsets .push_back (std::move (newOffsets));
318+ globalOffsets .push_back (std::move (newOffsets));
318319 }
319- return distributedOffsets ;
320+ return globalOffsets ;
320321}
321322
322323// Utility function to get sgShape, sgOffsetList for a given
@@ -408,12 +409,12 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
408409 SmallVector<OpFoldResult> wgOffsets = getWgOffsets (op, rewriter);
409410
410411 // Calculate the global offsets
411- auto distributedOffsets =
412- computeDistributedOffsets (op, sgOffsetList, wgOffsets, rewriter);
412+ auto globalOffsets =
413+ computeGlobalOffsets (op, sgOffsetList, wgOffsets, rewriter);
413414
414415 SmallVector<Value> newLoadOps;
415416 for (auto [offsets, tdesc] :
416- llvm::zip (distributedOffsets , adaptor.getTensorDesc ())) {
417+ llvm::zip (globalOffsets , adaptor.getTensorDesc ())) {
417418 VectorType newResTy = VectorType::get (
418419 sgShape,
419420 dyn_cast<xegpu::TensorDescType>(tdesc.getType ()).getElementType ());
@@ -449,11 +450,11 @@ struct WgToSgStoreNdOpWithOffset
449450 SmallVector<OpFoldResult> wgOffsets = getWgOffsets (op, rewriter);
450451
451452 // Calculate the global offsets
452- auto distributedOffsets =
453- computeDistributedOffsets (op, sgOffsetList, wgOffsets, rewriter);
453+ auto globalOffsets =
454+ computeGlobalOffsets (op, sgOffsetList, wgOffsets, rewriter);
454455
455456 for (auto [offsets, tdesc, value] : llvm::zip (
456- distributedOffsets , adaptor.getTensorDesc (), adaptor.getValue ())) {
457+ globalOffsets , adaptor.getTensorDesc (), adaptor.getValue ())) {
457458 rewriter.create <xegpu::StoreNdOp>(op.getLoc (), value, tdesc, offsets,
458459 op.getL1HintAttr (), op.getL2HintAttr (),
459460 op.getL3HintAttr ());
@@ -483,11 +484,11 @@ struct WgToSgPrefetchNdOpWithOffset
483484 SmallVector<OpFoldResult> wgOffsets = getWgOffsets (op, rewriter);
484485
485486 // Calculate the global offsets
486- auto distributedOffsets =
487- computeDistributedOffsets (op, sgOffsetList, wgOffsets, rewriter);
487+ auto globalOffsets =
488+ computeGlobalOffsets (op, sgOffsetList, wgOffsets, rewriter);
488489
489490 for (auto [offsets, tdesc] :
490- llvm::zip (distributedOffsets , adaptor.getTensorDesc ())) {
491+ llvm::zip (globalOffsets , adaptor.getTensorDesc ())) {
491492 rewriter.create <xegpu::PrefetchNdOp>(
492493 op.getLoc (), tdesc, offsets, op.getL1HintAttr (), op.getL2HintAttr (),
493494 op.getL3HintAttr ());
0 commit comments