Skip to content

Commit 8c487ff

Browse files
committed
operand reordering in alphabetical order
1 parent dab3957 commit 8c487ff

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

mlir/lib/Conversion/Normalize/Normalize.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ struct NormalizePass : public impl::NormalizeBase<NormalizePass> {
5454
void nameAsInitialOperation(mlir::Operation* op);
5555
void nameAsRegularOperation(mlir::Operation* op, llvm::SmallPtrSet<const mlir::Operation *, 32> &visited);
5656
bool hasOnlyImmediateOperands(mlir::Operation* op);
57-
void SetDeterministicNames(Block &block);
5857
llvm::SetVector<int> getOutputFootprint(mlir::Operation* op, llvm::SmallPtrSet<const mlir::Operation *, 32> &visited);
5958
void foldOperation(mlir::Operation* op);
59+
void reorderOperationOperandsByName(mlir::Operation* op);
6060
mlir::OpPrintingFlags flags{};
6161
};
6262
} // namespace
@@ -81,18 +81,11 @@ void NormalizePass::runOnOperation() {
8181
for (Block &block : region)
8282
for (Operation &innerOp : block)
8383
foldOperation(&innerOp);
84-
}
85-
}
8684

87-
void NormalizePass::SetDeterministicNames(Block &block) {
88-
static size_t VarCounter = 0;
89-
90-
for (Operation &innerOp : block) {
91-
mlir::OpBuilder b(innerOp.getContext());
92-
mlir::StringAttr sat =
93-
b.getStringAttr(llvm::formatv("v{0}", VarCounter++).str());
94-
mlir::Location newLoc = mlir::NameLoc::get(sat, innerOp.getLoc());
95-
innerOp.setLoc(newLoc);
85+
for (Region& region : op.getRegions())
86+
for (Block &block : region)
87+
for (Operation &innerOp : block)
88+
reorderOperationOperandsByName(&innerOp);
9689
}
9790
}
9891

@@ -368,6 +361,30 @@ void NormalizePass::foldOperation(mlir::Operation* op) {
368361
op->setLoc(newLoc);
369362
}
370363

364+
void NormalizePass::reorderOperationOperandsByName(mlir::Operation* op) {
365+
if(op->getNumOperands() == 0) return;
366+
367+
SmallVector<std::pair<std::string, mlir::Value>, 4> Operands;
368+
369+
for (mlir::Value operand : op->getOperands()) {
370+
std::string TextRepresentation;
371+
llvm::raw_string_ostream Stream(TextRepresentation);
372+
operand.printAsOperand(Stream, flags);
373+
Operands.push_back({Stream.str(), operand});
374+
}
375+
376+
if (op->hasTrait<OpTrait::IsCommutative>()) {
377+
llvm::sort(Operands.begin(), Operands.end(),
378+
[](const auto &a, const auto &b) {
379+
return llvm::StringRef(a.first).compare_insensitive(b.first) < 0;
380+
});
381+
}
382+
383+
for(size_t i = 0; i < Operands.size(); i++) {
384+
op->setOperand(i, Operands[i].second);
385+
}
386+
}
387+
371388
void NormalizePass::reorderOperations(SmallVector<Operation *, 16> &Outputs) {
372389
llvm::SmallPtrSet<const mlir::Operation *, 32> visited;
373390
for (auto *op : Outputs)

0 commit comments

Comments
 (0)