Skip to content

Commit 7069062

Browse files
committed
Fix format and C++20 issue
1 parent 05cdabb commit 7069062

File tree

3 files changed

+71
-64
lines changed

3 files changed

+71
-64
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
#include "mlir/Support/LLVM.h"
1919
#include "mlir/Support/LogicalResult.h"
2020
#include "llvm/ADT/ArrayRef.h"
21-
#include "llvm/ADT/SmallVector.h"
2221
#include "llvm/ADT/STLFunctionalExtras.h"
22+
#include "llvm/ADT/SmallVector.h"
2323

2424
namespace mlir {
2525
class Location;
@@ -236,9 +236,10 @@ memref::AllocaOp allocToAlloca(
236236
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
237237
/// this is not possible because this function uses the Affine dialect and the
238238
/// MemRef dialect cannot depend on the Affine dialect.
239-
SmallVector<OpFoldResult>
240-
getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
241-
ArrayRef<OpFoldResult> origSizes, unsigned groupId);
239+
SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
240+
OpBuilder &builder,
241+
ArrayRef<OpFoldResult> origSizes,
242+
unsigned groupId);
242243

243244
/// Compute the expanded strides of the given \p expandShape for the
244245
/// \p groupId-th reassociation group.
@@ -277,11 +278,10 @@ SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
277278
///
278279
/// \pre for all index in indices: index < values.size()
279280
/// \pre for all index in indices: index < maybeConstants.size()
280-
OpFoldResult
281-
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
282-
ArrayRef<int64_t> maybeConstants,
283-
ArrayRef<OpFoldResult> values,
284-
llvm::function_ref<bool(int64_t)> isDynamic);
281+
OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
282+
Location loc, ArrayRef<int64_t> maybeConstants,
283+
ArrayRef<OpFoldResult> values,
284+
llvm::function_ref<bool(int64_t)> isDynamic);
285285

286286
/// Compute the collapsed size of the given \p collapseShape for the
287287
/// \p groupId-th reassociation group.
@@ -291,9 +291,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
291291
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
292292
/// this is not possible because this function uses the Affine dialect and the
293293
/// MemRef dialect cannot depend on the Affine dialect.
294-
SmallVector<OpFoldResult>
295-
getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
296-
ArrayRef<OpFoldResult> origSizes, unsigned groupId);
294+
SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
295+
OpBuilder &builder,
296+
ArrayRef<OpFoldResult> origSizes,
297+
unsigned groupId);
297298

298299
/// Compute the collapsed stride of the given \p collpaseShape for the
299300
/// \p groupId-th reassociation group.
@@ -307,10 +308,11 @@ getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
307308
///
308309
/// \post result.size() == 1, in other words, each group collapse to one
309310
/// dimension.
310-
SmallVector<OpFoldResult>
311-
getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
312-
ArrayRef<OpFoldResult> origSizes,
313-
ArrayRef<OpFoldResult> origStrides, unsigned groupId);
311+
SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
312+
OpBuilder &builder,
313+
ArrayRef<OpFoldResult> origSizes,
314+
ArrayRef<OpFoldResult> origStrides,
315+
unsigned groupId);
314316

315317
} // namespace memref
316318
} // namespace mlir

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
#include "mlir/Dialect/Utils/IndexingUtils.h"
2222
#include "mlir/IR/AffineMap.h"
2323
#include "mlir/IR/BuiltinTypes.h"
24+
#include "mlir/IR/OpDefinition.h"
2425
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2526
#include "llvm/ADT/STLExtras.h"
2627
#include "llvm/ADT/SmallBitVector.h"
27-
#include "mlir/IR/OpDefinition.h"
2828
#include <optional>
2929

3030
namespace mlir {
@@ -256,9 +256,10 @@ struct ExtractStridedMetadataOpSubviewFolder
256256

257257
namespace mlir {
258258
namespace memref {
259-
SmallVector<OpFoldResult>
260-
getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
261-
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
259+
SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
260+
OpBuilder &builder,
261+
ArrayRef<OpFoldResult> origSizes,
262+
unsigned groupId) {
262263
SmallVector<int64_t, 2> reassocGroup =
263264
expandShape.getReassociationIndices()[groupId];
264265
assert(!reassocGroup.empty() &&
@@ -372,11 +373,10 @@ SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
372373
return expandedStrides;
373374
}
374375

375-
OpFoldResult
376-
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
377-
ArrayRef<int64_t> maybeConstants,
378-
ArrayRef<OpFoldResult> values,
379-
llvm::function_ref<bool(int64_t)> isDynamic) {
376+
OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
377+
Location loc, ArrayRef<int64_t> maybeConstants,
378+
ArrayRef<OpFoldResult> values,
379+
llvm::function_ref<bool(int64_t)> isDynamic) {
380380
AffineExpr productOfValues = builder.getAffineConstantExpr(1);
381381
SmallVector<OpFoldResult> inputValues;
382382
unsigned numberOfSymbols = 0;
@@ -410,9 +410,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
410410
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
411411
/// this is not possible because this function uses the Affine dialect and the
412412
/// MemRef dialect cannot depend on the Affine dialect.
413-
SmallVector<OpFoldResult>
414-
getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
415-
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
413+
SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
414+
OpBuilder &builder,
415+
ArrayRef<OpFoldResult> origSizes,
416+
unsigned groupId) {
416417
SmallVector<OpFoldResult> collapsedSize;
417418

418419
MemRefType collapseShapeType = collapseShape.getResultType();
@@ -451,10 +452,11 @@ getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
451452
///
452453
/// \post result.size() == 1, in other words, each group collapse to one
453454
/// dimension.
454-
SmallVector<OpFoldResult>
455-
getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
456-
ArrayRef<OpFoldResult> origSizes,
457-
ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
455+
SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
456+
OpBuilder &builder,
457+
ArrayRef<OpFoldResult> origSizes,
458+
ArrayRef<OpFoldResult> origStrides,
459+
unsigned groupId) {
458460
SmallVector<int64_t, 2> reassocGroup =
459461
collapseShape.getReassociationIndices()[groupId];
460462
assert(!reassocGroup.empty() &&

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2222
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2323
#include "mlir/IR/Attributes.h"
24-
#include "mlir/IR/DialectResourceBlobManager.h"
2524
#include "mlir/IR/Builders.h"
2625
#include "mlir/IR/BuiltinTypes.h"
26+
#include "mlir/IR/DialectResourceBlobManager.h"
2727
#include "mlir/IR/OpDefinition.h"
2828
#include "mlir/IR/PatternMatch.h"
2929
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,7 +48,6 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
4848
return cast<Value>(in);
4949
}
5050

51-
5251
/// Returns a collapsed memref and the linearized index to access the element
5352
/// at the specified indices.
5453
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -281,9 +280,8 @@ struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
281280
return {};
282281
}
283282

284-
LogicalResult
285-
matchAndRewrite(memref::GlobalOp globalOp,
286-
PatternRewriter &rewriter) const override {
283+
LogicalResult matchAndRewrite(memref::GlobalOp globalOp,
284+
PatternRewriter &rewriter) const override {
287285
auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
288286
if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
289287
return failure();
@@ -314,7 +312,8 @@ struct FlattenCollapseShape final
314312
memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
315313

316314
SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
317-
SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
315+
SmallVector<OpFoldResult> origStrides =
316+
metadata.getConstifiedMixedStrides();
318317
OpFoldResult offset = metadata.getConstifiedMixedOffset();
319318

320319
SmallVector<OpFoldResult> collapsedSizes;
@@ -338,7 +337,8 @@ struct FlattenCollapseShape final
338337
}
339338
};
340339

341-
struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp> {
340+
struct FlattenExpandShape final
341+
: public OpRewritePattern<memref::ExpandShapeOp> {
342342
using OpRewritePattern::OpRewritePattern;
343343

344344
LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
@@ -348,7 +348,8 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
348348
memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
349349

350350
SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
351-
SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
351+
SmallVector<OpFoldResult> origStrides =
352+
metadata.getConstifiedMixedStrides();
352353
OpFoldResult offset = metadata.getConstifiedMixedOffset();
353354

354355
SmallVector<OpFoldResult> expandedSizes;
@@ -372,7 +373,6 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
372373
}
373374
};
374375

375-
376376
// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
377377
struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
378378
using OpRewritePattern::OpRewritePattern;
@@ -405,9 +405,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
405405
for (OpFoldResult ofr : mixedOffsets)
406406
offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
407407

408-
auto [flatSource, linearOffset] =
409-
getFlattenMemrefAndOffset(rewriter, loc, op.getSource(),
410-
ValueRange(offsetValues));
408+
auto [flatSource, linearOffset] = getFlattenMemrefAndOffset(
409+
rewriter, loc, op.getSource(), ValueRange(offsetValues));
411410

412411
memref::ExtractStridedMetadataOp sourceMetadata =
413412
memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
@@ -424,9 +423,14 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
424423
resultStrides.reserve(resultType.getRank());
425424

426425
OpFoldResult resultOffset = sourceOffset;
427-
for (auto [idx, it] : llvm::enumerate(llvm::zip_equal(
426+
for (auto zipped : llvm::enumerate(llvm::zip_equal(
428427
mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) {
429-
auto [offsetOfr, strideOfr, sizeOfr, relativeStrideOfr] = it;
428+
auto idx = zipped.index();
429+
auto it = zipped.value();
430+
auto offsetOfr = std::get<0>(it);
431+
auto strideOfr = std::get<1>(it);
432+
auto sizeOfr = std::get<2>(it);
433+
auto relativeStrideOfr = std::get<3>(it);
430434
OpFoldResult contribution = [&]() -> OpFoldResult {
431435
if (Attribute offsetAttr = dyn_cast<Attribute>(offsetOfr)) {
432436
if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
@@ -449,7 +453,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
449453
}
450454
}
451455
Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
452-
Value contribVal = getValueFromOpFoldResult(rewriter, loc, contribution);
456+
Value contribVal =
457+
getValueFromOpFoldResult(rewriter, loc, contribution);
453458
return rewriter.create<arith::AddIOp>(loc, offsetVal, contribVal)
454459
.getResult();
455460
}();
@@ -478,12 +483,12 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
478483
memref::LinearizedMemRefInfo linearizedInfo;
479484
[[maybe_unused]] OpFoldResult linearizedIndex;
480485
std::tie(linearizedInfo, linearizedIndex) =
481-
memref::getLinearizedMemRefOffsetAndSize(
482-
rewriter, loc, elementBitWidth, elementBitWidth, resultOffset,
483-
resultSizes, resultStrides);
486+
memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
487+
elementBitWidth, resultOffset,
488+
resultSizes, resultStrides);
484489

485-
Value flattenedSize = getValueFromOpFoldResult(
486-
rewriter, loc, linearizedInfo.linearizedSize);
490+
Value flattenedSize =
491+
getValueFromOpFoldResult(rewriter, loc, linearizedInfo.linearizedSize);
487492
Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
488493

489494
Value flattenedSubview = memref::SubViewOp::create(
@@ -524,10 +529,11 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
524529
using OpRewritePattern::OpRewritePattern;
525530

526531
LogicalResult matchAndRewrite(memref::GetGlobalOp op,
527-
PatternRewriter &rewriter) const override {
532+
PatternRewriter &rewriter) const override {
528533
// Check if this get_global references a multi-dimensional global
529534
auto module = op->template getParentOfType<ModuleOp>();
530-
auto globalOp = module.template lookupSymbol<memref::GlobalOp>(op.getName());
535+
auto globalOp =
536+
module.template lookupSymbol<memref::GlobalOp>(op.getName());
531537
if (!globalOp) {
532538
return failure();
533539
}
@@ -537,12 +543,13 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
537543

538544
// Only apply if the global has been flattened but the get_global hasn't
539545
if (globalType.getRank() == 1 && resultType.getRank() > 1) {
540-
auto newGetGlobal = memref::GetGlobalOp::create(
541-
rewriter, op.getLoc(), globalType, op.getName());
546+
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, op.getLoc(),
547+
globalType, op.getName());
542548

543549
// Cast the flattened result back to the original shape
544550
memref::ExtractStridedMetadataOp stridedMetadata =
545-
memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(), op.getResult());
551+
memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(),
552+
op.getResult());
546553
auto castResult = memref::ReinterpretCastOp::create(
547554
rewriter, op.getLoc(), resultType, newGetGlobal,
548555
/*offset=*/rewriter.getIndexAttr(0),
@@ -572,13 +579,9 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
572579
MemRefRewritePattern<memref::StoreOp>,
573580
MemRefRewritePattern<memref::AllocOp>,
574581
MemRefRewritePattern<memref::AllocaOp>,
575-
MemRefRewritePattern<memref::DeallocOp>,
576-
FlattenExpandShape,
577-
FlattenCollapseShape,
578-
FlattenSubView,
579-
FlattenGetGlobal,
580-
FlattenGlobal>(
581-
patterns.getContext());
582+
MemRefRewritePattern<memref::DeallocOp>, FlattenExpandShape,
583+
FlattenCollapseShape, FlattenSubView, FlattenGetGlobal,
584+
FlattenGlobal>(patterns.getContext());
582585
}
583586

584587
void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {

0 commit comments

Comments
 (0)