Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlir/include/mlir/TableGen/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,10 @@ class Pattern {
using IdentifierLine = std::pair<StringRef, unsigned>;

// Returns the file location of the pattern (buffer identifier + line number
// pair).
std::vector<IdentifierLine> getLocation() const;
// pair). If `forSourceOutput` is true, replace absolute paths in the buffer
// identifier with just their filename so that we don't leak build paths into
// the generated code.
std::vector<IdentifierLine> getLocation(bool forSourceOutput = false) const;

// Recursively collects all bound symbols inside the DAG tree rooted
// at `tree` and updates the given `infoMap`.
Expand Down
21 changes: 17 additions & 4 deletions mlir/lib/TableGen/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Path.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

Expand Down Expand Up @@ -771,15 +772,27 @@ int Pattern::getBenefit() const {
return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
}

std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
std::vector<Pattern::IdentifierLine>
Pattern::getLocation(bool forSourceOutput) const {
std::vector<std::pair<StringRef, unsigned>> result;
result.reserve(def.getLoc().size());
for (auto loc : def.getLoc()) {
unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
assert(buf && "invalid source location");
result.emplace_back(
llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
llvm::SrcMgr.getLineAndColumn(loc, buf).first);

StringRef bufferName =
llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier();
// If we're emitting a generated file, we'd like to have some indication of
// where our patterns came from. However, LLVM's build rules use absolute
// paths as arguments to TableGen, and naively echoing such paths makes the
// contents of the generated source file depend on the build location,
// making MLIR builds substantially less reproducable. As a compromise, we
// trim absolute paths back to only the filename component.
if (forSourceOutput && llvm::sys::path::is_absolute(bufferName))
bufferName = llvm::sys::path::filename(bufferName);

result.emplace_back(bufferName,
llvm::SrcMgr.getLineAndColumn(loc, buf).first);
}
return result;
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/tools/mlir-tblgen/RewriterGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");

// Emit RewritePattern for Pattern.
auto locs = pattern.getLocation();
auto locs = pattern.getLocation(/*forSourceOutput=*/true);
os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
llvm::reverse(locs));
os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
Expand Down