Skip to content

Commit a40f47c

Browse files
[mlir][python] automatic location inference (#151246)
This PR implements "automatic" location inference in the Python bindings. It works by walking the Python frame stack and collecting source locations (Python captures these in the frame itself). It is inspired by JAX's [implementation](https://github.com/jax-ml/jax/blob/523ddcfbcad005deab5a7d542df4c706f5ee5e9c/jax/_src/interpreters/mlir.py#L462) but moves the frame stack traversal into the bindings for better performance. The system supports registering "included" and "excluded" filenames: frames originating from functions in included filenames **will not** be filtered, and frames originating from functions in excluded filenames **will** be filtered. Inclusion is checked first, so it takes precedence over exclusion. This allows excluding all the generated `*_ops_gen.py` files. The system is also "toggleable" and off by default, to spare people who have their own systems (such as JAX) the added cost. Note that the system stores the entire stack trace (subject to `locTracebackFramesLimit`) in the `Location`, specifically as a `CallSiteLoc`. This can be useful for profiling tools (flamegraphs etc.). Shoutout to the folks at JAX for coming up with a good system. --------- Co-authored-by: Jacques Pienaar <[email protected]>
1 parent da3182a commit a40f47c

File tree

10 files changed

+347
-57
lines changed

10 files changed

+347
-57
lines changed

mlir/lib/Bindings/Python/Globals.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,19 @@
1010
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
1111

1212
#include <optional>
13+
#include <regex>
1314
#include <string>
15+
#include <unordered_set>
1416
#include <vector>
1517

1618
#include "NanobindUtils.h"
1719
#include "mlir-c/IR.h"
1820
#include "mlir/CAPI/Support.h"
1921
#include "llvm/ADT/DenseMap.h"
22+
#include "llvm/ADT/StringExtras.h"
2023
#include "llvm/ADT/StringRef.h"
2124
#include "llvm/ADT/StringSet.h"
25+
#include "llvm/Support/Regex.h"
2226

2327
namespace mlir {
2428
namespace python {
@@ -114,6 +118,39 @@ class PyGlobals {
114118
std::optional<nanobind::object>
115119
lookupOperationClass(llvm::StringRef operationName);
116120

121+
class TracebackLoc {
122+
public:
123+
bool locTracebacksEnabled();
124+
125+
void setLocTracebacksEnabled(bool value);
126+
127+
size_t locTracebackFramesLimit();
128+
129+
void setLocTracebackFramesLimit(size_t value);
130+
131+
void registerTracebackFileInclusion(const std::string &file);
132+
133+
void registerTracebackFileExclusion(const std::string &file);
134+
135+
bool isUserTracebackFilename(llvm::StringRef file);
136+
137+
static constexpr size_t kMaxFrames = 512;
138+
139+
private:
140+
nanobind::ft_mutex mutex;
141+
bool locTracebackEnabled_ = false;
142+
size_t locTracebackFramesLimit_ = 10;
143+
std::unordered_set<std::string> userTracebackIncludeFiles;
144+
std::unordered_set<std::string> userTracebackExcludeFiles;
145+
std::regex userTracebackIncludeRegex;
146+
bool rebuildUserTracebackIncludeRegex = false;
147+
std::regex userTracebackExcludeRegex;
148+
bool rebuildUserTracebackExcludeRegex = false;
149+
llvm::StringMap<bool> isUserTracebackFilenameCache;
150+
};
151+
152+
/// Returns the traceback-location settings held by this PyGlobals instance.
TracebackLoc &getTracebackLoc() { return tracebackLoc; }
153+
117154
private:
118155
static PyGlobals *instance;
119156

@@ -134,6 +171,8 @@ class PyGlobals {
134171
/// Set of dialect namespaces that we have attempted to import implementation
135172
/// modules for.
136173
llvm::StringSet<> loadedDialectModules;
174+
175+
TracebackLoc tracebackLoc;
137176
};
138177

139178
} // namespace python

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 104 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@
2020
#include "nanobind/nanobind.h"
2121
#include "llvm/ADT/ArrayRef.h"
2222
#include "llvm/ADT/SmallVector.h"
23-
#include "llvm/Support/raw_ostream.h"
2423

2524
#include <optional>
26-
#include <system_error>
27-
#include <utility>
2825

2926
namespace nb = nanobind;
3027
using namespace nb::literals;
@@ -1523,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name,
15231520
llvm::ArrayRef<MlirValue> operands,
15241521
std::optional<nb::dict> attributes,
15251522
std::optional<std::vector<PyBlock *>> successors,
1526-
int regions, DefaultingPyLocation location,
1523+
int regions, PyLocation &location,
15271524
const nb::object &maybeIp, bool inferType) {
15281525
llvm::SmallVector<MlirType, 4> mlirResults;
15291526
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1627,7 +1624,7 @@ nb::object PyOperation::create(std::string_view name,
16271624
if (!operation.ptr)
16281625
throw nb::value_error("Operation creation failed");
16291626
PyOperationRef created =
1630-
PyOperation::createDetached(location->getContext(), operation);
1627+
PyOperation::createDetached(location.getContext(), operation);
16311628
maybeInsertOperation(created, maybeIp);
16321629

16331630
return created.getObject();
@@ -1937,9 +1934,9 @@ nb::object PyOpView::buildGeneric(
19371934
std::optional<nb::list> resultTypeList, nb::list operandList,
19381935
std::optional<nb::dict> attributes,
19391936
std::optional<std::vector<PyBlock *>> successors,
1940-
std::optional<int> regions, DefaultingPyLocation location,
1937+
std::optional<int> regions, PyLocation &location,
19411938
const nb::object &maybeIp) {
1942-
PyMlirContextRef context = location->getContext();
1939+
PyMlirContextRef context = location.getContext();
19431940

19441941
// Class level operation construction metadata.
19451942
// Operand and result segment specs are either none, which does no
@@ -2789,6 +2786,90 @@ class PyOpAttributeMap {
27892786
PyOperationRef operation;
27902787
};
27912788

2789+
/// Builds an MLIR location from the Python call stack of the current thread.
///
/// Walks the frame stack from the innermost frame outwards and keeps at most
/// `locTracebackFramesLimit()` frames whose filename passes
/// `isUserTracebackFilename`. Each kept frame becomes a name location (the
/// function name) wrapping a file location. When several frames are kept they
/// are chained into nested call-site locations with the innermost frame as
/// the callee. Returns an unknown location when no frame is kept.
///
/// Throws `nb::python_error` if `PyCode_Addr2Location` reports failure
/// (Python >= 3.11 only).
MlirLocation tracebackToLocation(MlirContext ctx) {
  PyGlobals::TracebackLoc &tracebackLoc = PyGlobals::get().getTracebackLoc();
  size_t framesLimit = tracebackLoc.locTracebackFramesLimit();
  // Use a thread_local here to avoid requiring a large amount of space.
  thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
      frames;
  // Never index past the buffer, whatever limit is configured.
  if (framesLimit > frames.size())
    framesLimit = frames.size();
  size_t count = 0;

  nb::gil_scoped_acquire acquire;
  PyThreadState *tstate = PyThreadState_GET();

  // PyThreadState_GetFrame and PyFrame_GetBack both return a strong reference
  // (or null). Holding the current frame in an owning nb::object releases it
  // when we advance, when the loop ends, and when an exception propagates, so
  // the frame (and everything it keeps alive) can be collected.
  auto previousFrame = [](const nb::object &frame) {
    return nb::steal<nb::object>(reinterpret_cast<PyObject *>(
        PyFrame_GetBack(reinterpret_cast<PyFrameObject *>(frame.ptr()))));
  };
  for (nb::object frame = nb::steal<nb::object>(
           reinterpret_cast<PyObject *>(PyThreadState_GetFrame(tstate)));
       frame.is_valid() && count < framesLimit;
       frame = previousFrame(frame)) {
    PyFrameObject *pyFrame = reinterpret_cast<PyFrameObject *>(frame.ptr());

    // PyFrame_GetCode also returns a strong reference; own it so it is
    // released on every path out of this iteration (including `continue`).
    nb::object codeOwner = nb::steal<nb::object>(
        reinterpret_cast<PyObject *>(PyFrame_GetCode(pyFrame)));
    PyCodeObject *code = reinterpret_cast<PyCodeObject *>(codeOwner.ptr());

    auto fileNameStr =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
    llvm::StringRef fileName(fileNameStr);
    if (!tracebackLoc.isUserTracebackFilename(fileName))
      continue;

#if PY_VERSION_HEX < 0x030b00f0
    std::string name =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
    llvm::StringRef funcName(name);
    int startLine = PyFrame_GetLineNumber(pyFrame);
    MlirLocation loc =
        mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
#else
    // co_qualname and PyCode_Addr2Location added in py3.11
    std::string name =
        nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
    llvm::StringRef funcName(name);
    int startLine, startCol, endLine, endCol;
    int lasti = PyFrame_GetLasti(pyFrame);
    // NOTE(review): nb::python_error expects the Python error indicator to be
    // set; confirm PyCode_Addr2Location sets one when it returns 0.
    if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
                              &endCol)) {
      throw nb::python_error();
    }
    MlirLocation loc = mlirLocationFileLineColRangeGet(
        ctx, wrap(fileName), startLine, startCol, endLine, endCol);
#endif

    frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
    ++count;
  }

  if (count == 0)
    return mlirLocationUnknownGet(ctx);

  // frames[0] is the innermost (callee) frame and frames[count - 1] the
  // outermost kept caller. Fold from the outside in so that every frame
  // becomes the callee of the chain built from its callers. With a single
  // frame the loop body never runs and that frame is returned as-is.
  MlirLocation result = frames[count - 1];
  assert(!mlirLocationIsNull(result) && "expected non-null caller location");
  for (size_t i = count - 1; i-- > 0;) {
    assert(!mlirLocationIsNull(frames[i]) &&
           "expected non-null callee location");
    result = mlirLocationCallSiteGet(frames[i], result);
  }
  return result;
}
2859+
2860+
/// Picks the location to use for a newly created operation or module.
///
/// Resolution order:
///   1. the explicitly supplied `location`, if any;
///   2. `DefaultingPyLocation::resolve()` when traceback locations are
///      disabled;
///   3. otherwise a location built from the current Python call stack, in the
///      context given by `DefaultingPyMlirContext::resolve()`.
PyLocation
maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
  if (location)
    return *location;

  if (PyGlobals::get().getTracebackLoc().locTracebacksEnabled()) {
    PyMlirContext &context = DefaultingPyMlirContext::resolve();
    MlirLocation inferred = tracebackToLocation(context.get());
    PyMlirContextRef contextRef = PyMlirContext::forContext(context.get());
    return {contextRef, inferred};
  }

  return DefaultingPyLocation::resolve();
}
2872+
27922873
} // namespace
27932874

27942875
//------------------------------------------------------------------------------
@@ -3052,10 +3133,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
30523133
.def("__eq__", [](PyLocation &self, nb::object other) { return false; })
30533134
.def_prop_ro_static(
30543135
"current",
3055-
[](nb::object & /*class*/) {
3136+
[](nb::object & /*class*/) -> std::optional<PyLocation *> {
30563137
auto *loc = PyThreadContextEntry::getDefaultLocation();
30573138
if (!loc)
3058-
throw nb::value_error("No current Location");
3139+
return std::nullopt;
30593140
return loc;
30603141
},
30613142
"Gets the Location bound to the current thread or raises ValueError")
@@ -3240,8 +3321,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32403321
kModuleParseDocstring)
32413322
.def_static(
32423323
"create",
3243-
[](DefaultingPyLocation loc) {
3244-
MlirModule module = mlirModuleCreateEmpty(loc);
3324+
[](const std::optional<PyLocation> &loc) {
3325+
PyLocation pyLoc = maybeGetTracebackLocation(loc);
3326+
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
32453327
return PyModule::forModule(module).releaseObject();
32463328
},
32473329
nb::arg("loc").none() = nb::none(), "Creates an empty module")
@@ -3462,8 +3544,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34623544
std::optional<std::vector<PyValue *>> operands,
34633545
std::optional<nb::dict> attributes,
34643546
std::optional<std::vector<PyBlock *>> successors, int regions,
3465-
DefaultingPyLocation location, const nb::object &maybeIp,
3466-
bool inferType) {
3547+
const std::optional<PyLocation> &location,
3548+
const nb::object &maybeIp, bool inferType) {
34673549
// Unpack/validate operands.
34683550
llvm::SmallVector<MlirValue, 4> mlirOperands;
34693551
if (operands) {
@@ -3475,8 +3557,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34753557
}
34763558
}
34773559

3560+
PyLocation pyLoc = maybeGetTracebackLocation(location);
34783561
return PyOperation::create(name, results, mlirOperands, attributes,
3479-
successors, regions, location, maybeIp,
3562+
successors, regions, pyLoc, maybeIp,
34803563
inferType);
34813564
},
34823565
nb::arg("name"), nb::arg("results").none() = nb::none(),
@@ -3520,12 +3603,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35203603
std::optional<nb::list> resultTypeList, nb::list operandList,
35213604
std::optional<nb::dict> attributes,
35223605
std::optional<std::vector<PyBlock *>> successors,
3523-
std::optional<int> regions, DefaultingPyLocation location,
3606+
std::optional<int> regions,
3607+
const std::optional<PyLocation> &location,
35243608
const nb::object &maybeIp) {
3609+
PyLocation pyLoc = maybeGetTracebackLocation(location);
35253610
new (self) PyOpView(PyOpView::buildGeneric(
35263611
name, opRegionSpec, operandSegmentSpecObj,
35273612
resultSegmentSpecObj, resultTypeList, operandList,
3528-
attributes, successors, regions, location, maybeIp));
3613+
attributes, successors, regions, pyLoc, maybeIp));
35293614
},
35303615
nb::arg("name"), nb::arg("opRegionSpec"),
35313616
nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3559,17 +3644,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
35593644
[](nb::handle cls, std::optional<nb::list> resultTypeList,
35603645
nb::list operandList, std::optional<nb::dict> attributes,
35613646
std::optional<std::vector<PyBlock *>> successors,
3562-
std::optional<int> regions, DefaultingPyLocation location,
3647+
std::optional<int> regions, std::optional<PyLocation> location,
35633648
const nb::object &maybeIp) {
35643649
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
35653650
std::tuple<int, bool> opRegionSpec =
35663651
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
35673652
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
35683653
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3654+
PyLocation pyLoc = maybeGetTracebackLocation(location);
35693655
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
35703656
resultSegmentSpec, resultTypeList,
35713657
operandList, attributes, successors,
3572-
regions, location, maybeIp);
3658+
regions, pyLoc, maybeIp);
35733659
},
35743660
nb::arg("cls"), nb::arg("results").none() = nb::none(),
35753661
nb::arg("operands").none() = nb::none(),

mlir/lib/Bindings/Python/IRModule.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313

1414
#include "Globals.h"
1515
#include "NanobindUtils.h"
16+
#include "mlir-c/Bindings/Python/Interop.h"
1617
#include "mlir-c/Support.h"
1718
#include "mlir/Bindings/Python/Nanobind.h"
18-
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1919

2020
namespace nb = nanobind;
2121
using namespace mlir;
@@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
197197
// Not found and loading did not yield a registration.
198198
return std::nullopt;
199199
}
200+
201+
bool PyGlobals::TracebackLoc::locTracebacksEnabled() {
202+
nanobind::ft_lock_guard lock(mutex);
203+
return locTracebackEnabled_;
204+
}
205+
206+
/// Turns traceback-derived locations on (`true`) or off (`false`).
void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) {
  // The flag is only written while holding the mutex.
  nanobind::ft_lock_guard guard(mutex);
  locTracebackEnabled_ = value;
}
210+
211+
size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() {
212+
nanobind::ft_lock_guard lock(mutex);
213+
return locTracebackFramesLimit_;
214+
}
215+
216+
/// Sets the maximum number of stack frames recorded in one location.
/// Requests above `kMaxFrames` are clamped down to `kMaxFrames`.
void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) {
  nanobind::ft_lock_guard guard(mutex);
  const size_t clamped = value < kMaxFrames ? value : kMaxFrames;
  locTracebackFramesLimit_ = clamped;
}
220+
221+
void PyGlobals::TracebackLoc::registerTracebackFileInclusion(
222+
const std::string &file) {
223+
nanobind::ft_lock_guard lock(mutex);
224+
auto reg = "^" + llvm::Regex::escape(file);
225+
if (userTracebackIncludeFiles.insert(reg).second)
226+
rebuildUserTracebackIncludeRegex = true;
227+
if (userTracebackExcludeFiles.count(reg)) {
228+
if (userTracebackExcludeFiles.erase(reg))
229+
rebuildUserTracebackExcludeRegex = true;
230+
}
231+
}
232+
233+
void PyGlobals::TracebackLoc::registerTracebackFileExclusion(
234+
const std::string &file) {
235+
nanobind::ft_lock_guard lock(mutex);
236+
auto reg = "^" + llvm::Regex::escape(file);
237+
if (userTracebackExcludeFiles.insert(reg).second)
238+
rebuildUserTracebackExcludeRegex = true;
239+
if (userTracebackIncludeFiles.count(reg)) {
240+
if (userTracebackIncludeFiles.erase(reg))
241+
rebuildUserTracebackIncludeRegex = true;
242+
}
243+
}
244+
245+
bool PyGlobals::TracebackLoc::isUserTracebackFilename(
246+
const llvm::StringRef file) {
247+
nanobind::ft_lock_guard lock(mutex);
248+
if (rebuildUserTracebackIncludeRegex) {
249+
userTracebackIncludeRegex.assign(
250+
llvm::join(userTracebackIncludeFiles, "|"));
251+
rebuildUserTracebackIncludeRegex = false;
252+
isUserTracebackFilenameCache.clear();
253+
}
254+
if (rebuildUserTracebackExcludeRegex) {
255+
userTracebackExcludeRegex.assign(
256+
llvm::join(userTracebackExcludeFiles, "|"));
257+
rebuildUserTracebackExcludeRegex = false;
258+
isUserTracebackFilenameCache.clear();
259+
}
260+
if (!isUserTracebackFilenameCache.contains(file)) {
261+
std::string fileStr = file.str();
262+
bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
263+
bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
264+
isUserTracebackFilenameCache[file] = include || !exclude;
265+
}
266+
return isUserTracebackFilenameCache[file];
267+
}

0 commit comments

Comments
 (0)