Commit a569bc8

clang format fix
1 parent: 47fe143

File tree: 2 files changed (+55, -46 lines)

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 45 additions & 35 deletions
@@ -407,37 +407,40 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     if (!tdescTy.isScattered())
       return failure();
-
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
-
+
     SmallVector<int64_t> targetIndiceShape(*targetShape);
     int64_t originalChunkSize = tdescTy.getChunkSize();
     // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
     if (originalChunkSize > 1)
       targetIndiceShape.pop_back();

     auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-    SmallVector<Type> convertedIndiceTypes =
+    SmallVector<Type> convertedIndiceTypes =
         getUnrolledTypes(indiceVecTy, targetIndiceShape);
-    SmallVector<Value> convertedIndiceVec =
+    SmallVector<Value> convertedIndiceVec =
         pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
-
+
     SmallVector<Value> newOps;

     // More indices are needed when chunkSize > 1, since a big load from one
     // address could be broken into multiple small loads.
     if (originalChunkSize > 1) {
       int64_t blockedChunkSize = targetShape->back();
-      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

-      for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+      for (auto [indice, indiceType] :
+           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
         for (int64_t i = 0; i < numNewChunks; ++i) {
           // Compute the offset
-          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunkSize);
+          Value inc = rewriter.create<arith::ConstantIndexOp>(
+              loc, i * blockedChunkSize);
           Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
-          Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
+          Value offsetIndice =
+              rewriter.create<arith::AddIOp>(loc, indice, incVec);

           auto newOp = rewriter.create<xegpu::CreateDescOp>(
               loc, newTdescTy, op.getSource(), offsetIndice);
@@ -447,11 +450,11 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
       }
     } else {
       for (auto indice : convertedIndiceVec) {
-        auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
-                                                          op.getSource(), indice);
+        auto newOp = rewriter.create<xegpu::CreateDescOp>(
+            loc, newTdescTy, op.getSource(), indice);
         newOps.push_back(newOp);
       }
-    }
+    }

     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
@@ -471,11 +474,11 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {

     if (!tdescTy.isScattered())
       return failure();
-
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
-
+
     SmallVector<int64_t> targetMaskShape(*targetShape);
     int64_t originalChunkSize = tdescTy.getChunkSize();

@@ -489,29 +492,31 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

-    SmallVector<Type> convertedMaskTypes;
-    SmallVector<Value> convertedMasks;
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;

     if (originalChunkSize > 1) {
       targetMaskShape.pop_back();
       convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
-      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
       int64_t blockedChunkSize = targetShape->back();
-      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

       for (auto mask : convertedMasks1D) {
         for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedMasks.push_back(mask);
         }
       }
-      // This is to handle the transpose effect when chunkSize > 1.
+      // This is to handle the transpose effect when chunkSize > 1.
       if (targetShape && targetShape->size() > 1) {
         std::swap((*targetShape)[0], (*targetShape)[1]);
         newValueTy = valueTy.cloneWith(*targetShape, elemTy);
       }
     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
-      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
+                            loc, rewriter);
     }

     SmallVector<Value> newOps;
@@ -521,7 +526,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
           op.getL2HintAttr(), op.getL3HintAttr());
       newOps.push_back(newOp);
     }
-
+
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
     return success();
@@ -576,38 +581,40 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     int64_t originalChunkSize = tdescTy.getChunkSize();

     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

-    SmallVector<Type> convertedMaskTypes;
-    SmallVector<Value> convertedMasks;
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;

     if (originalChunkSize > 1) {
       int64_t blockedChunkSize = targetShape->back();
-      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
       convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
-      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);

       for (auto mask : convertedMasks1D) {
         for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedMasks.push_back(mask);
         }
       }
-      // This is to handle the transpose effect when chunkSize > 1.
+      // This is to handle the transpose effect when chunkSize > 1.
       std::swap((*targetShape)[0], (*targetShape)[1]);

     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
-      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+      convertedMasks =
+          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
     }

     SmallVector<Type> convertedValTypes =
         getUnrolledTypes(valueTy, *targetShape);
     SmallVector<Value> convertedValues =
-      pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -630,7 +637,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

-    if (tdescTy.getRank() >2)
+    if (tdescTy.getRank() > 2)
       return failure();

     if (!tdescTy.isScattered())
@@ -652,12 +659,14 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     SmallVector<Value> newOps;
     int64_t originalChunkSize = tdescTy.getChunkSize();
     if (originalChunkSize > 1) {
-      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+      SmallVector<int64_t> shape1D(targetShape->begin(),
+                                   targetShape->end() - 1);
       convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
-      SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+      SmallVector<Value> convertedOffsetVec1D =
+          pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);

       int64_t blockedChunkSize = targetShape->back();
-      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

       for (auto offset : convertedOffsetVec1D) {
         for (int64_t i = 0; i < numNewChunks; ++i) {
@@ -667,8 +676,9 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {

     } else {
       convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
-      convertedOffsetVec = pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-    }
+      convertedOffsetVec =
+          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }

     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
       auto newOp =
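The changes in this file are formatting-only; the unroll arithmetic itself is unchanged. For readers new to these patterns, the shared idea is that a scattered descriptor whose chunk size exceeds the blocked chunk size of the unroll target is split into originalChunkSize / blockedChunkSize smaller descriptors, each shifted by i * blockedChunkSize. A minimal standalone C++ sketch of that arithmetic (illustrative values and names, not the MLIR pattern itself):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      int64_t originalChunkSize = 8; // chunk_size of the original tensor_desc (assumed)
      int64_t blockedChunkSize = 2;  // last dim of the unroll target shape (assumed)
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      // Each per-lane base index expands into numNewChunks shifted copies,
      // mirroring the loop that builds offsetIndice in UnrollCreateDescOp.
      std::vector<int64_t> indices = {0, 32, 64}; // assumed per-lane base offsets
      for (int64_t indice : indices)
        for (int64_t i = 0; i < numNewChunks; ++i)
          std::cout << "new desc offset: " << indice + i * blockedChunkSize << "\n";
      return 0;
    }

The same quotient drives the mask handling in UnrollLoadGatherOp and UnrollStoreScatterOp above, where each unrolled 1-D mask is pushed numNewChunks times.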

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 10 additions & 11 deletions
@@ -19,12 +19,10 @@ using namespace mlir::xegpu;

 namespace {

-
 #define DEBUG_TYPE "test-xegpu-unroll"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

-
 struct TestXeGPUUnrollingPatterns
     : public PassWrapper<TestXeGPUUnrollingPatterns,
                          OperationPass<gpu::GPUModuleOp>> {
@@ -60,7 +58,8 @@ struct TestXeGPUUnrollingPatterns
       xegpu::TensorDescType tdescTy;
       if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
         tdescTy = createNdOp.getType();
-      } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+      } else if (auto updateNdOp =
+                     dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
         tdescTy = updateNdOp.getTensorDescType();
       } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
         tdescTy = prefetchNdOp.getTensorDescType();
@@ -105,28 +104,27 @@ struct TestXeGPUUnrollingPatterns
         Attribute encoding = tdescTy.getEncoding();
         auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
             tdescTy.getLayout());
-
+
         if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
-          auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+          auto scatterAttr =
+              mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
           int64_t chunkSize = scatterAttr.getChunkSize().getInt();
-
+
           if (chunkSize > 1) {
             int64_t newChunkSize = chunkSize;
             auto instData = layout.getInstData();
             if (!instData.empty())
-                newChunkSize = instData.asArrayRef().back();
+              newChunkSize = instData.asArrayRef().back();

             auto chunkSizeAttr = mlir::IntegerAttr::get(
-                  mlir::IntegerType::get(ctx, 64), newChunkSize);
+                mlir::IntegerType::get(ctx, 64), newChunkSize);

             // To create a new attribute with a different chunk_size:
             auto newEncoding = xegpu::ScatterTensorDescAttr::get(
                 ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);

             encoding = newEncoding;
-
           }
-
         }
         if (layout) {
           if (layout.getLaneLayout() == nullptr)
@@ -135,7 +133,8 @@ struct TestXeGPUUnrollingPatterns
            layout = layout.dropInstData();
          }

-          newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, layout);
+          newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+                                             layout);

        } else {
          newTy = type.clone(tileShape, elemTy);
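For context on the chunk-size rewrite above: when the layout carries inst_data, the test pass takes its last dimension as the new chunk_size for the scattered tensor_desc encoding; otherwise the original chunk_size is kept. A minimal sketch of that selection rule, assuming a plain vector stands in for instData.asArrayRef() (the helper deriveChunkSize is hypothetical, not part of the pass):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for the selection done in TestXeGPUUnrollingPatterns:
    // when inst_data is present, its last dimension becomes the new chunk_size.
    int64_t deriveChunkSize(int64_t chunkSize,
                            const std::vector<int64_t> &instData) {
      if (chunkSize > 1 && !instData.empty())
        return instData.back();
      return chunkSize;
    }

    int main() {
      std::cout << deriveChunkSize(8, {16, 2}) << "\n"; // prints 2
      std::cout << deriveChunkSize(8, {}) << "\n";      // prints 8 (unchanged)
      return 0;
    }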
