Skip to content

Commit 3b12279

Browse files
authored
Bump LLVM to 9344b2196cbc36cdc577314bbb2b889606ba6820. (#8404)
There were series of patches to llvm::EquivalenceClasses this week, and we updated accordingly: * There is no longer a template parameter for the comparator * The idiom for iterating over EquivalenceClasses was updated * The idiom for iterating over members was updated There was a change to the InlinerUtils functions to now include a required callback for ops that are cloned. To update our call sites, this default constructs an InlinerConfig and uses the default clone callback. Finally, there was a change to the memref.subview verifier to actually check bounds when statically possible. Since memref.subview is technically supposed to "represent a reduced-size view of the original memref", I think our previous lowering was technically supposed to be interpreted as out of bound accesses. In any case, it seems this is the expected use for memref.collapse_shape, so this was adopted instead of memref.subview for this use case.
1 parent 5c7db4f commit 3b12279

File tree

8 files changed

+44
-52
lines changed

8 files changed

+44
-52
lines changed

lib/Dialect/FIRRTL/Transforms/CheckCombLoops.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -613,14 +613,12 @@ class DiscoverLoops {
613613
llvm::dbgs() << "\n node:" << getName(drivenBy[node.first].first)
614614
<< "=> probe:" << getName(drivenBy[node.second].first);
615615
}
616-
for (auto i = rwProbeClasses.begin(), e = rwProbeClasses.end(); i != e;
617-
++i) { // Iterate over all of the equivalence sets.
616+
for (const auto &i :
617+
rwProbeClasses) { // Iterate over all of the equivalence sets.
618618
if (!i->isLeader())
619619
continue; // Ignore non-leader sets.
620620
// Print members in this set.
621-
llvm::interleave(llvm::make_range(rwProbeClasses.member_begin(i),
622-
rwProbeClasses.member_end()),
623-
llvm::dbgs(), "\n");
621+
llvm::interleave(rwProbeClasses.members(*i), llvm::dbgs(), "\n");
624622
llvm::dbgs() << "\n dataflow at leader::" << i->getData() << "\n =>"
625623
<< rwProbeRefersTo[i->getData()];
626624
llvm::dbgs() << "\n Done\n"; // Finish set.
@@ -653,10 +651,7 @@ class DiscoverLoops {
653651
rwProbeClasses.getLeaderValue(getOrAddNode(defOp.getDataRef()));
654652
// For all the probes, that are in the same eqv class, i.e., refer to
655653
// the same value.
656-
for (auto probe :
657-
llvm::make_range(rwProbeClasses.member_begin(
658-
rwProbeClasses.findValue(rwProbeNode)),
659-
rwProbeClasses.member_end())) {
654+
for (auto probe : rwProbeClasses.members(rwProbeNode)) {
660655
auto probeVal = drivenBy[probe].first;
661656
// If the probe is a port, then record the path from the probe to the
662657
// input port.

lib/Dialect/FIRRTL/Transforms/InferResets.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,12 +1028,10 @@ LogicalResult InferResetsPass::inferAndUpdateResets() {
10281028
llvm::dbgs() << "\n";
10291029
debugHeader("Infer reset types") << "\n\n";
10301030
});
1031-
for (auto it = resetClasses.begin(), end = resetClasses.end(); it != end;
1032-
++it) {
1031+
for (const auto &it : resetClasses) {
10331032
if (!it->isLeader())
10341033
continue;
1035-
ResetNetwork net = llvm::make_range(resetClasses.member_begin(it),
1036-
resetClasses.member_end());
1034+
ResetNetwork net = resetClasses.members(*it);
10371035

10381036
// Infer whether this should be a sync or async reset.
10391037
auto kind = inferReset(net);

lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class LowerXMRPass : public circt::firrtl::impl::LowerXMRBase<LowerXMRPass> {
130130
CircuitNamespace ns(getOperation());
131131
circuitNamespace = &ns;
132132

133-
llvm::EquivalenceClasses<Value, ValueComparator> eq;
133+
llvm::EquivalenceClasses<Value> eq;
134134
dataFlowClasses = &eq;
135135

136136
InstanceGraph &instanceGraph = getAnalysis<InstanceGraph>();
@@ -370,14 +370,12 @@ class LowerXMRPass : public circt::firrtl::impl::LowerXMRBase<LowerXMRPass> {
370370
}
371371

372372
LLVM_DEBUG({
373-
for (auto I = dataFlowClasses->begin(), E = dataFlowClasses->end();
374-
I != E; ++I) { // Iterate over all of the equivalence sets.
373+
for (const auto &I :
374+
*dataFlowClasses) { // Iterate over all of the equivalence sets.
375375
if (!I->isLeader())
376376
continue; // Ignore non-leader sets.
377377
// Print members in this set.
378-
llvm::interleave(llvm::make_range(dataFlowClasses->member_begin(I),
379-
dataFlowClasses->member_end()),
380-
llvm::dbgs(), "\n");
378+
llvm::interleave(dataFlowClasses->members(*I), llvm::dbgs(), "\n");
381379
llvm::dbgs() << "\n dataflow at leader::" << I->getData() << "\n =>";
382380
auto iter = dataflowAt.find(I->getData());
383381
if (iter != dataflowAt.end()) {
@@ -887,15 +885,7 @@ class LowerXMRPass : public circt::firrtl::impl::LowerXMRBase<LowerXMRPass> {
887885
/// no NextNodeOnPath, which denotes a leaf node on the path.
888886
SmallVector<XMRNode> refSendPathList;
889887

890-
/// llvm::EquivalenceClasses wants comparable elements. This comparator uses
891-
/// uses pointer comparison on the Impl.
892-
struct ValueComparator {
893-
bool operator()(const Value &lhs, const Value &rhs) const {
894-
return lhs.getImpl() < rhs.getImpl();
895-
}
896-
};
897-
898-
llvm::EquivalenceClasses<Value, ValueComparator> *dataFlowClasses;
888+
llvm::EquivalenceClasses<Value> *dataFlowClasses;
899889
// Instance and module ref ports that needs to be removed.
900890
DenseMap<Operation *, llvm::BitVector> refPortsToRemoveMap;
901891

lib/Dialect/HW/Transforms/FlattenModules.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "circt/Support/BackedgeBuilder.h"
1313
#include "mlir/IR/IRMapping.h"
1414
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Transforms/Inliner.h"
1516
#include "mlir/Transforms/InliningUtils.h"
1617
#include "llvm/ADT/PostOrderIterator.h"
1718
#include "llvm/Support/Debug.h"
@@ -28,6 +29,7 @@ namespace hw {
2829
using namespace circt;
2930
using namespace hw;
3031
using namespace igraph;
32+
using mlir::InlinerConfig;
3133
using mlir::InlinerInterface;
3234

3335
namespace {
@@ -97,6 +99,8 @@ void FlattenModulesPass::runOnOperation() {
9799
auto &instanceGraph = getAnalysis<hw::InstanceGraph>();
98100
DenseSet<Operation *> handled;
99101

102+
InlinerConfig config;
103+
100104
// Iterate over all instances in the instance graph. This ensures we visit
101105
// every module, even private top modules (private and never instantiated).
102106
for (auto *startNode : instanceGraph) {
@@ -130,7 +134,8 @@ void FlattenModulesPass::runOnOperation() {
130134
bool isLastModuleUse = --numUsesLeft == 0;
131135

132136
PrefixingInliner inliner(&getContext(), inst.getInstanceName());
133-
if (failed(mlir::inlineRegion(inliner, &module.getBody(), inst,
137+
if (failed(mlir::inlineRegion(inliner, config.getCloneCallback(),
138+
&module.getBody(), inst,
134139
inst.getOperands(), inst.getResults(),
135140
std::nullopt, !isLastModuleUse))) {
136141
inst.emitError("failed to inline '")

lib/Dialect/LLHD/Transforms/FunctionEliminationPass.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/Visitors.h"
1919
#include "mlir/Interfaces/CallInterfaces.h"
2020
#include "mlir/Pass/Pass.h"
21+
#include "mlir/Transforms/Inliner.h"
2122
#include "mlir/Transforms/InliningUtils.h"
2223
#include "llvm/Support/LogicalResult.h"
2324

@@ -68,6 +69,7 @@ void FunctionEliminationPass::runOnOperation() {
6869
LogicalResult FunctionEliminationPass::runOnModule(hw::HWModuleOp module) {
6970
FunctionInliner inliner(&getContext());
7071
SymbolTableCollection table;
72+
mlir::InlinerConfig config;
7173

7274
SmallVector<CallOpInterface> calls;
7375
module.walk([&](func::CallOp op) { calls.push_back(op); });
@@ -81,8 +83,8 @@ LogicalResult FunctionEliminationPass::runOnModule(hw::HWModuleOp module) {
8183
auto func = cast<CallableOpInterface>(
8284
table.lookupNearestSymbolFrom(module, symbol.getLeafReference()));
8385

84-
if (succeeded(
85-
mlir::inlineCall(inliner, call, func, func.getCallableRegion()))) {
86+
if (succeeded(mlir::inlineCall(inliner, config.getCloneCallback(), call,
87+
func, func.getCallableRegion()))) {
8688
call->erase();
8789
continue;
8890
}

lib/Transforms/FlattenMemRefs.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
2020
#include "mlir/Dialect/MemRef/IR/MemRef.h"
21+
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
2122
#include "mlir/IR/BuiltinDialect.h"
2223
#include "mlir/IR/BuiltinTypes.h"
2324
#include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -397,26 +398,26 @@ static void populateFlattenMemRefsLegality(ConversionTarget &target) {
397398
}
398399

399400
// Materializes a multidimensional memory to unidimensional memory by using a
400-
// memref.subview operation.
401+
// memref.collapse_shape operation.
401402
// TODO: This is also possible for dynamically shaped memories.
402-
static Value materializeSubViewFlattening(OpBuilder &builder, MemRefType type,
403-
ValueRange inputs, Location loc) {
403+
static Value materializeCollapseShapeFlattening(OpBuilder &builder,
404+
MemRefType type,
405+
ValueRange inputs,
406+
Location loc) {
404407
assert(type.hasStaticShape() &&
405408
"Can only subview flatten memref's with static shape (for now...).");
406409
MemRefType sourceType = cast<MemRefType>(inputs[0].getType());
407410
int64_t memSize = sourceType.getNumElements();
408-
unsigned dims = sourceType.getShape().size();
411+
ArrayRef<int64_t> sourceShape = sourceType.getShape();
412+
ArrayRef<int64_t> targetShape = ArrayRef<int64_t>(memSize);
409413

410-
// Build offset, sizes and strides
411-
SmallVector<OpFoldResult> sizes(dims, builder.getIndexAttr(0));
412-
SmallVector<OpFoldResult> offsets(dims, builder.getIndexAttr(1));
413-
offsets[offsets.size() - 1] = builder.getIndexAttr(memSize);
414-
SmallVector<OpFoldResult> strides(dims, builder.getIndexAttr(1));
414+
// Build ReassociationIndices to collapse completely to 1D MemRef.
415+
auto indices = getReassociationIndicesForCollapse(sourceShape, targetShape);
416+
assert(indices.has_value() && "expected a valid collapse");
415417

416418
// Generate the appropriate return type:
417-
MemRefType outType = MemRefType::get({memSize}, type.getElementType());
418-
return builder.create<memref::SubViewOp>(loc, outType, inputs[0], sizes,
419-
offsets, strides);
419+
return builder.create<memref::CollapseShapeOp>(loc, inputs[0],
420+
indices.value());
420421
}
421422

422423
static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
@@ -489,7 +490,7 @@ struct FlattenMemRefCallsPass
489490

490491
// Add a target materializer to handle memory flattening through
491492
// memref.subview operations.
492-
typeConverter.addTargetMaterialization(materializeSubViewFlattening);
493+
typeConverter.addTargetMaterialization(materializeCollapseShapeFlattening);
493494

494495
if (applyPartialConversion(getOperation(), target, std::move(patterns))
495496
.failed()) {

llvm

Submodule llvm updated 8404 files

test/Transforms/flatten_memref_calls.mlir

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
// CHECK-LABEL: func private @foo(memref<900xi32>) -> i32
44

5-
// CHECK-LABEL: func @main() {
6-
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
7-
// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32>
8-
// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_1]][0, 0] [1, 900] [1, 1] : memref<30x30xi32> to memref<900xi32>
9-
// CHECK: %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<900xi32>) -> i32
10-
// CHECK: return
11-
// CHECK: }
5+
// CHECK-LABEL: func @main() {
6+
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
7+
// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32>
8+
// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_1]]
9+
// CHECK-SAME{LITERAL}: [[0, 1]] : memref<30x30xi32> into memref<900xi32>
10+
// CHECK: %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<900xi32>) -> i32
11+
// CHECK: return
12+
// CHECK: }
1213
module {
1314
func.func private @foo(memref<30x30xi32>) -> i32
1415
func.func @main() {

0 commit comments

Comments
 (0)