
Commit 545f937

format code
1 parent 1d4dc72 commit 545f937

3 files changed: 58 additions & 56 deletions

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ namespace mlir {
 class RewritePatternSet;

 namespace vector {
-struct UnrollVectorOptions;
+struct UnrollVectorOptions;
 } // namespace vector

 namespace xegpu {

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 32 additions & 29 deletions
@@ -211,7 +211,8 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
         mixedOffsets[x] = addi(oldX, subOffX);
         mixedOffsets[y] = addi(oldY, subOffY);
         auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
-            loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(), op.getMixedStrides());
+            loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
+            op.getMixedStrides());
         newOps.push_back(newOp);
       }
     }
@@ -304,20 +305,21 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {

     auto elemTy = tdescTy.getElementType();
     auto newValueTy = valueTy.cloneWith(targetShape, elemTy);
-    auto newTdescTy = xegpu::TensorDescType::get(ctx, targetShape, elemTy, tdescTy.getEncoding(),
-                                                 getLaneLayoutAttr(layout));
+    auto newTdescTy = xegpu::TensorDescType::get(ctx, targetShape, elemTy,
+                                                 tdescTy.getEncoding(),
+                                                 getLaneLayoutAttr(layout));

     auto numNewOps = computeProduct(grids);
     llvm::SmallVector<Type> convertedValTypes(numNewOps, newValueTy);
     llvm::SmallVector<Type> convertedTdescTypes(numNewOps, newTdescTy);
-    auto convertedValues = addPackOp(op.getValue(), convertedValTypes, targetShape, loc, rewriter);
+    auto convertedValues =
+        addPackOp(op.getValue(), convertedValTypes, targetShape, loc, rewriter);
     auto convertedTdescs = addPackOp(op.getTensorDesc(), convertedTdescTypes,
                                      targetShape, loc, rewriter);

     for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) {
       rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
-                                        op.getL2HintAttr(),
-                                        op.getL3HintAttr());
+                                        op.getL2HintAttr(), op.getL3HintAttr());
     }
     rewriter.eraseOp(op);
     return success();
@@ -395,27 +397,27 @@ struct XeGPUUnrollPass final

   void runOnOperation() override {
     vector::UnrollVectorOptions options;
-    options.setNativeShapeFn(
-        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
-          if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
-            xegpu::TensorDescType tdescTy;
-            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
-              tdescTy = createNdOp.getType();
-            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
-              tdescTy = loadNdOp.getTensorDescType();
-            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
-              tdescTy = storeNdOp.getTensorDescType();
-            }
-
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              if (auto inst_data = layout.getInstData())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
-            }
-          }
-
-          return std::nullopt;
-        });
+    options.setNativeShapeFn([&](Operation *op)
+                                 -> std::optional<SmallVector<int64_t>> {
+      if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+        xegpu::TensorDescType tdescTy;
+        if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+          tdescTy = createNdOp.getType();
+        } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+          tdescTy = loadNdOp.getTensorDescType();
+        } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+          tdescTy = storeNdOp.getTensorDescType();
+        }
+
+        if (auto layout = tdescTy.getLayoutAttr()) {
+          if (auto inst_data = layout.getInstData())
+            return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                        inst_data.asArrayRef().end());
+        }
+      }
+
+      return std::nullopt;
+    });

     auto funcOp = getOperation();
     RewritePatternSet patterns(&getContext());
@@ -432,7 +434,8 @@ struct XeGPUUnrollPass final
 } // namespace

 void mlir::xegpu::populateXeGPUUnrollPatterns(
-    RewritePatternSet &patterns, const mlir::vector::UnrollVectorOptions &options) {
+    RewritePatternSet &patterns,
+    const mlir::vector::UnrollVectorOptions &options) {
   patterns.add<UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp>(
-      patterns.getContext(), options);
+      patterns.getContext(), options);
 }
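
For reference, a minimal sketch of how downstream code might drive the entry point reformatted above. The helper name unrollXeGPUOps, the hard-coded {8, 16} native shape, and the greedy-driver call are illustrative assumptions, not part of this commit; the passes in this diff instead derive the native shape from the tensor_desc layout's inst_data, and newer MLIR trees may spell the driver applyPatternsGreedily.

// Sketch only: wires xegpu::populateXeGPUUnrollPatterns into a rewrite driver.
// unrollXeGPUOps and the fixed 8x16 tile shape are hypothetical.
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void unrollXeGPUOps(mlir::Operation *root) {
  using namespace mlir;
  vector::UnrollVectorOptions options;
  // Request a fixed 8x16 native shape for nd loads/stores; real users would
  // compute this per op (e.g. from the tensor_desc layout, as the pass does).
  options.setNativeShapeFn(
      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        if (isa<xegpu::LoadNdOp, xegpu::StoreNdOp>(op))
          return SmallVector<int64_t>{8, 16};
        return std::nullopt;
      });

  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  // Greedy driver from GreedyPatternRewriteDriver.h applies the patterns.
  (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
}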

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 25 additions & 26 deletions
@@ -7,12 +7,12 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

 using namespace mlir;
 using namespace mlir::xegpu;
@@ -44,28 +44,27 @@ struct TestXeGPUUnrollingPatterns

   void runOnOperation() override {
     vector::UnrollVectorOptions options;
-    options.setNativeShapeFn(
-        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
-          if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
-            xegpu::TensorDescType tdescTy;
-            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
-              tdescTy = createNdOp.getType();
-            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
-              tdescTy = loadNdOp.getTensorDescType();
-            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
-              tdescTy = storeNdOp.getTensorDescType();
-            }
-
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              if (auto inst_data = layout.getInstData())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
-            }
-          }
-
-          return std::nullopt;
-        });
-
+    options.setNativeShapeFn([&](Operation *op)
+                                 -> std::optional<SmallVector<int64_t>> {
+      if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+        xegpu::TensorDescType tdescTy;
+        if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+          tdescTy = createNdOp.getType();
+        } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+          tdescTy = loadNdOp.getTensorDescType();
+        } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+          tdescTy = storeNdOp.getTensorDescType();
+        }
+
+        if (auto layout = tdescTy.getLayoutAttr()) {
+          if (auto inst_data = layout.getInstData())
+            return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                        inst_data.asArrayRef().end());
+        }
+      }
+
+      return std::nullopt;
+    });

     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
@@ -82,5 +81,5 @@ namespace test {
 void registerTestXeGPULowerings() {
   PassRegistration<TestXeGPUUnrollingPatterns>();
 }
-}
-}
+} // namespace test
+} // namespace mlir
