Skip to content

Commit 9ed096d

Browse files
committed
refactor repeated logic into foldOperations
1 parent f4d0c63 commit 9ed096d

File tree

2 files changed

+38
-95
lines changed

2 files changed

+38
-95
lines changed

mlir/lib/Conversion/Normalize/Normalize.cpp

Lines changed: 35 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -163,51 +163,10 @@ std::string inline split(std::string_view str, const char &delimiter,
163163
void NormalizePass::nameAsInitialOperation(
164164
mlir::Operation *op,
165165
llvm::SmallPtrSet<const mlir::Operation *, 32> &visited) {
166-
SmallVector<SmallString<64>, 4> Operands;
167166

168-
if (op->getNumOperands() == 0) {
169-
if (auto call = mlir::dyn_cast<mlir::func::CallOp>(op)) {
170-
Operands.push_back(StringRef(std::string{"void"}));
171-
} else {
172-
std::string TextRepresentation;
173-
mlir::AsmState state(op, flags);
174-
llvm::raw_string_ostream Stream(TextRepresentation);
175-
op->print(Stream, state);
176-
std::string hash = to_string(strHash(split(Stream.str(), '=', 1)));
177-
Operands.push_back(StringRef(hash));
178-
}
179-
} else {
180-
for (mlir::Value operand : op->getOperands()) {
181-
if (mlir::Operation *defOp = operand.getDefiningOp()) {
182-
RenameOperation(defOp, visited);
183-
184-
std::string TextRepresentation;
185-
mlir::AsmState state(defOp, flags);
186-
llvm::raw_string_ostream Stream(TextRepresentation);
187-
defOp->print(Stream, state);
188-
Operands.push_back(StringRef(split(Stream.str(), '=', 0)));
189-
} else if (auto ba = dyn_cast<mlir::BlockArgument>(operand)) {
190-
mlir::Block *ownerBlock = ba.getOwner();
191-
unsigned argIndex = ba.getArgNumber();
192-
if (auto func =
193-
dyn_cast<mlir::func::FuncOp>(ownerBlock->getParentOp())) {
194-
if (&func.front() == ownerBlock) {
195-
Operands.push_back(
196-
StringRef(std::string("funcArg" + std::to_string(argIndex))));
197-
} else {
198-
Operands.push_back(
199-
StringRef(std::string("blockArg" + std::to_string(argIndex))));
200-
}
201-
} else {
202-
Operands.push_back(
203-
StringRef(std::string("blockArg" + std::to_string(argIndex))));
204-
}
205-
}
206-
}
207-
}
208-
209-
if (op->hasTrait<OpTrait::IsCommutative>())
210-
llvm::sort(Operands);
167+
for (mlir::Value operand : op->getOperands())
168+
if (mlir::Operation *defOp = operand.getDefiningOp())
169+
RenameOperation(defOp, visited);
211170

212171
uint64_t Hash = MagicHashConstant;
213172

@@ -228,14 +187,20 @@ void NormalizePass::nameAsInitialOperation(
228187
Name.append(callee.str());
229188
}
230189

231-
Name.append("$");
232-
for (unsigned long i = 0; i < Operands.size(); ++i) {
233-
Name.append(std::string(Operands[i]));
234-
235-
if (i < Operands.size() - 1)
236-
Name.append("-");
190+
if (op->getNumOperands() == 0) {
191+
Name.append("$");
192+
if (auto call = mlir::dyn_cast<mlir::func::CallOp>(op)) {
193+
Name.append("void");
194+
} else {
195+
std::string TextRepresentation;
196+
mlir::AsmState state(op, flags);
197+
llvm::raw_string_ostream Stream(TextRepresentation);
198+
op->print(Stream, state);
199+
std::string hash = to_string(strHash(split(Stream.str(), '=', 1)));
200+
Name.append(hash);
201+
}
202+
Name.append("$");
237203
}
238-
Name.append("$");
239204

240205
mlir::OpBuilder b(op->getContext());
241206
mlir::StringAttr sat = b.getStringAttr(Name);
@@ -246,36 +211,10 @@ void NormalizePass::nameAsInitialOperation(
246211
void NormalizePass::nameAsRegularOperation(
247212
mlir::Operation *op,
248213
llvm::SmallPtrSet<const mlir::Operation *, 32> &visited) {
249-
SmallVector<SmallString<64>, 4> Operands;
250-
for (mlir::Value operand : op->getOperands()) {
251-
if (mlir::Operation *defOp = operand.getDefiningOp()) {
252-
RenameOperation(defOp, visited);
253-
254-
std::string TextRepresentation;
255-
mlir::AsmState state(defOp, flags);
256-
llvm::raw_string_ostream Stream(TextRepresentation);
257-
defOp->print(Stream, state);
258-
Operands.push_back(StringRef(split(Stream.str(), '=', 0)));
259-
} else if (auto ba = dyn_cast<mlir::BlockArgument>(operand)) {
260-
mlir::Block *ownerBlock = ba.getOwner();
261-
unsigned argIndex = ba.getArgNumber();
262-
if (auto func = dyn_cast<mlir::func::FuncOp>(ownerBlock->getParentOp())) {
263-
if (&func.front() == ownerBlock) {
264-
Operands.push_back(
265-
StringRef(std::string("funcArg" + std::to_string(argIndex))));
266-
} else {
267-
Operands.push_back(
268-
StringRef(std::string("blockArg" + std::to_string(argIndex))));
269-
}
270-
} else {
271-
Operands.push_back(
272-
StringRef(std::string("blockArg" + std::to_string(argIndex))));
273-
}
274-
}
275-
}
276214

277-
if (op->hasTrait<OpTrait::IsCommutative>())
278-
llvm::sort(Operands);
215+
for (mlir::Value operand : op->getOperands())
216+
if (mlir::Operation *defOp = operand.getDefiningOp())
217+
RenameOperation(defOp, visited);
279218

280219
uint64_t Hash = MagicHashConstant;
281220

@@ -302,15 +241,6 @@ void NormalizePass::nameAsRegularOperation(
302241
Name.append(callee.str());
303242
}
304243

305-
Name.append("$");
306-
for (unsigned long i = 0; i < Operands.size(); ++i) {
307-
Name.append(Operands[i]);
308-
309-
if (i < Operands.size() - 1)
310-
Name.append("-");
311-
}
312-
Name.append("$");
313-
314244
mlir::OpBuilder b(op->getContext());
315245
mlir::StringAttr sat = b.getStringAttr(Name);
316246
mlir::Location newLoc = mlir::NameLoc::get(sat, op->getLoc());
@@ -324,7 +254,7 @@ bool inline starts_with(std::string_view base,
324254
}
325255

326256
void NormalizePass::foldOperation(mlir::Operation *op) {
327-
if (isOutput(*op))
257+
if (isOutput(*op) || op->getNumOperands() == 0)
328258
return;
329259

330260
std::string TextRepresentation;
@@ -333,7 +263,7 @@ void NormalizePass::foldOperation(mlir::Operation *op) {
333263
op->print(Stream, state);
334264

335265
auto opName = split(Stream.str(), '=', 0);
336-
if (!starts_with(opName, "%op"))
266+
if (!starts_with(opName, "%op") && !starts_with(opName, "%vl"))
337267
return;
338268

339269
SmallVector<std::string, 4> Operands;
@@ -354,7 +284,20 @@ void NormalizePass::foldOperation(mlir::Operation *op) {
354284
} else {
355285
Operands.push_back(name);
356286
}
357-
}
287+
} else if (auto ba = dyn_cast<mlir::BlockArgument>(operand)) {
288+
mlir::Block *ownerBlock = ba.getOwner();
289+
unsigned argIndex = ba.getArgNumber();
290+
if (auto func =
291+
dyn_cast<mlir::func::FuncOp>(ownerBlock->getParentOp())) {
292+
if (&func.front() == ownerBlock) {
293+
Operands.push_back(std::string("funcArg" + std::to_string(argIndex)));
294+
} else {
295+
Operands.push_back(std::string("blockArg" + std::to_string(argIndex)));
296+
}
297+
} else {
298+
Operands.push_back(std::string("blockArg" + std::to_string(argIndex)));
299+
}
300+
}
358301
}
359302

360303
if (op->hasTrait<OpTrait::IsCommutative>())

mlir/test/Conversion/Normalize/reorder.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
// CHECK-LABEL: func.func @bar(
44
// CHECK-SAME: %[[ARG0:.*]]: i32) -> i32 {
5-
// CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32
6-
// CHECK: %vl15831$51356-funcArg0$ = arith.addi %[[ARG0]], %[[VAL_0:.*]] : i32
5+
// CHECK: %vl14084$51356$ = arith.constant 2 : i32
6+
// CHECK: %vl15831$funcArg0-vl14084$ = arith.addi %[[ARG0]], %vl14084$51356$ : i32
77
// CHECK: %vl14084$187c2$ = arith.constant 6 : i32
8-
// CHECK: %op27844$vl14084-vl15831$ = arith.addi %vl14084$187c2$, %vl15831$51356-funcArg0$ : i32
8+
// CHECK: %op27844$vl14084-vl15831$ = arith.addi %vl14084$187c2$, %vl15831$funcArg0-vl14084$ : i32
99
// CHECK: %vl14084$4c6ac$ = arith.constant 8 : i32
1010
// CHECK: %op27844$op27844-vl14084$ = arith.addi %op27844$vl14084-vl15831$, %vl14084$4c6ac$ : i32
1111
// CHECK: return %op27844$op27844-vl14084$ : i32

0 commit comments

Comments
 (0)