@@ -54,9 +54,9 @@ struct NormalizePass : public impl::NormalizeBase<NormalizePass> {
5454 void nameAsInitialOperation (mlir::Operation* op);
5555 void nameAsRegularOperation (mlir::Operation* op, llvm::SmallPtrSet<const mlir::Operation *, 32 > &visited);
5656 bool hasOnlyImmediateOperands (mlir::Operation* op);
57- void SetDeterministicNames (Block &block);
5857 llvm::SetVector<int > getOutputFootprint (mlir::Operation* op, llvm::SmallPtrSet<const mlir::Operation *, 32 > &visited);
5958 void foldOperation (mlir::Operation* op);
59+ void reorderOperationOperandsByName (mlir::Operation* op);
6060 mlir::OpPrintingFlags flags{};
6161};
6262} // namespace
@@ -81,18 +81,11 @@ void NormalizePass::runOnOperation() {
8181 for (Block &block : region)
8282 for (Operation &innerOp : block)
8383 foldOperation (&innerOp);
84- }
85- }
8684
87- void NormalizePass::SetDeterministicNames (Block &block) {
88- static size_t VarCounter = 0 ;
89-
90- for (Operation &innerOp : block) {
91- mlir::OpBuilder b (innerOp.getContext ());
92- mlir::StringAttr sat =
93- b.getStringAttr (llvm::formatv (" v{0}" , VarCounter++).str ());
94- mlir::Location newLoc = mlir::NameLoc::get (sat, innerOp.getLoc ());
95- innerOp.setLoc (newLoc);
85+ for (Region& region : op.getRegions ())
86+ for (Block &block : region)
87+ for (Operation &innerOp : block)
88+ reorderOperationOperandsByName (&innerOp);
9689 }
9790}
9891
@@ -368,6 +361,30 @@ void NormalizePass::foldOperation(mlir::Operation* op) {
368361 op->setLoc (newLoc);
369362}
370363
364+ void NormalizePass::reorderOperationOperandsByName (mlir::Operation* op) {
365+ if (op->getNumOperands () == 0 ) return ;
366+
367+ SmallVector<std::pair<std::string, mlir::Value>, 4 > Operands;
368+
369+ for (mlir::Value operand : op->getOperands ()) {
370+ std::string TextRepresentation;
371+ llvm::raw_string_ostream Stream (TextRepresentation);
372+ operand.printAsOperand (Stream, flags);
373+ Operands.push_back ({Stream.str (), operand});
374+ }
375+
376+ if (op->hasTrait <OpTrait::IsCommutative>()) {
377+ llvm::sort (Operands.begin (), Operands.end (),
378+ [](const auto &a, const auto &b) {
379+ return llvm::StringRef (a.first ).compare_insensitive (b.first ) < 0 ;
380+ });
381+ }
382+
383+ for (size_t i = 0 ; i < Operands.size (); i++) {
384+ op->setOperand (i, Operands[i].second );
385+ }
386+ }
387+
371388void NormalizePass::reorderOperations (SmallVector<Operation *, 16 > &Outputs) {
372389 llvm::SmallPtrSet<const mlir::Operation *, 32 > visited;
373390 for (auto *op : Outputs)
0 commit comments