Skip to content

Commit 7d3dde7

Browse files
committed
change variable name
1 parent bbd38af commit 7d3dde7

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -296,13 +296,14 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
296296
}
297297
};
298298

299-
// Utility function to compute distributed offsets for subgroup operations.
299+
// Utility function to compute global offsets for subgroup operations.
300300
// Returns a vector of new offsets for each subgroup, given the original op's
301301
// offsets and subgroup relative offsets.
302-
static SmallVector<SmallVector<OpFoldResult>> computeDistributedOffsets(
303-
Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
304-
ArrayRef<OpFoldResult> wgOffsets, ConversionPatternRewriter &rewriter) {
305-
SmallVector<SmallVector<OpFoldResult>> distributedOffsets;
302+
static SmallVector<SmallVector<OpFoldResult>>
303+
computeGlobalOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
304+
ArrayRef<OpFoldResult> wgOffsets,
305+
ConversionPatternRewriter &rewriter) {
306+
SmallVector<SmallVector<OpFoldResult>> globalOffsets;
306307
Location loc = op->getLoc();
307308
for (const auto &sgOffsets : sgOffsetsList) {
308309
SmallVector<OpFoldResult> newOffsets;
@@ -314,9 +315,9 @@ static SmallVector<SmallVector<OpFoldResult>> computeDistributedOffsets(
314315
getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
315316
newOffsets.push_back(add);
316317
}
317-
distributedOffsets.push_back(std::move(newOffsets));
318+
globalOffsets.push_back(std::move(newOffsets));
318319
}
319-
return distributedOffsets;
320+
return globalOffsets;
320321
}
321322

322323
// Utility function to get sgShape, sgOffsetList for a given
@@ -408,12 +409,12 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
408409
SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter);
409410

410411
// Calculate the global offsets
411-
auto distributedOffsets =
412-
computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
412+
auto globalOffsets =
413+
computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter);
413414

414415
SmallVector<Value> newLoadOps;
415416
for (auto [offsets, tdesc] :
416-
llvm::zip(distributedOffsets, adaptor.getTensorDesc())) {
417+
llvm::zip(globalOffsets, adaptor.getTensorDesc())) {
417418
VectorType newResTy = VectorType::get(
418419
sgShape,
419420
dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
@@ -449,11 +450,11 @@ struct WgToSgStoreNdOpWithOffset
449450
SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter);
450451

451452
// Calculate the global offsets
452-
auto distributedOffsets =
453-
computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
453+
auto globalOffsets =
454+
computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter);
454455

455456
for (auto [offsets, tdesc, value] : llvm::zip(
456-
distributedOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
457+
globalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
457458
rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
458459
op.getL1HintAttr(), op.getL2HintAttr(),
459460
op.getL3HintAttr());
@@ -483,11 +484,11 @@ struct WgToSgPrefetchNdOpWithOffset
483484
SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter);
484485

485486
// calculate the global offsets
486-
auto distributedOffsets =
487-
computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
487+
auto globalOffsets =
488+
computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter);
488489

489490
for (auto [offsets, tdesc] :
490-
llvm::zip(distributedOffsets, adaptor.getTensorDesc())) {
491+
llvm::zip(globalOffsets, adaptor.getTensorDesc())) {
491492
rewriter.create<xegpu::PrefetchNdOp>(
492493
op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
493494
op.getL3HintAttr());

0 commit comments

Comments
 (0)