Skip to content
39 changes: 39 additions & 0 deletions mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@
#define MLIR_BINDINGS_PYTHON_GLOBALS_H

#include <optional>
#include <regex>
#include <string>
#include <unordered_set>
#include <vector>

#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Regex.h"

namespace mlir {
namespace python {
Expand Down Expand Up @@ -114,6 +118,39 @@ class PyGlobals {
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);

class TracebackLoc {
public:
bool locTracebacksEnabled();

void setLocTracebacksEnabled(bool value);

size_t locTracebackFramesLimit();

void setLocTracebackFramesLimit(size_t value);

void registerTracebackFileInclusion(const std::string &file);

void registerTracebackFileExclusion(const std::string &file);

bool isUserTracebackFilename(llvm::StringRef file);

static constexpr size_t kMaxFrames = 512;

private:
nanobind::ft_mutex mutex;
bool locTracebackEnabled_ = false;
size_t locTracebackFramesLimit_ = 10;
std::unordered_set<std::string> userTracebackIncludeFiles;
std::unordered_set<std::string> userTracebackExcludeFiles;
std::regex userTracebackIncludeRegex;
bool rebuildUserTracebackIncludeRegex = false;
std::regex userTracebackExcludeRegex;
bool rebuildUserTracebackExcludeRegex = false;
llvm::StringMap<bool> isUserTracebackFilenameCache;
};

TracebackLoc &getTracebackLoc() { return tracebackLoc; }

private:
static PyGlobals *instance;

Expand All @@ -134,6 +171,8 @@ class PyGlobals {
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;

TracebackLoc tracebackLoc;
};

} // namespace python
Expand Down
122 changes: 104 additions & 18 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@
#include "nanobind/nanobind.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

#include <optional>
#include <system_error>
#include <utility>

namespace nb = nanobind;
using namespace nb::literals;
Expand Down Expand Up @@ -1523,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name,
llvm::ArrayRef<MlirValue> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
int regions, DefaultingPyLocation location,
int regions, PyLocation &location,
const nb::object &maybeIp, bool inferType) {
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
Expand Down Expand Up @@ -1627,7 +1624,7 @@ nb::object PyOperation::create(std::string_view name,
if (!operation.ptr)
throw nb::value_error("Operation creation failed");
PyOperationRef created =
PyOperation::createDetached(location->getContext(), operation);
PyOperation::createDetached(location.getContext(), operation);
maybeInsertOperation(created, maybeIp);

return created.getObject();
Expand Down Expand Up @@ -1937,9 +1934,9 @@ nb::object PyOpView::buildGeneric(
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
std::optional<int> regions, PyLocation &location,
const nb::object &maybeIp) {
PyMlirContextRef context = location->getContext();
PyMlirContextRef context = location.getContext();

// Class level operation construction metadata.
// Operand and result segment specs are either none, which does no
Expand Down Expand Up @@ -2789,6 +2786,90 @@ class PyOpAttributeMap {
PyOperationRef operation;
};

MlirLocation tracebackToLocation(MlirContext ctx) {
size_t framesLimit =
PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
// Use a thread_local here to avoid requiring a large amount of space.
thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
frames;
size_t count = 0;

nb::gil_scoped_acquire acquire;
PyThreadState *tstate = PyThreadState_GET();
PyFrameObject *next;
PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
// In the increment expression:
// 1. get the next prev frame;
// 2. decrement the ref count on the current frame (in order that it can get
// gc'd, along with any objects in its closure and etc);
// 3. set current = next.
for (; pyFrame != nullptr && count < framesLimit;
next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
PyCodeObject *code = PyFrame_GetCode(pyFrame);
auto fileNameStr =
nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
llvm::StringRef fileName(fileNameStr);
if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
continue;

#if PY_VERSION_HEX < 0x030b00f0
std::string name =
nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
llvm::StringRef funcName(name);
int startLine = PyFrame_GetLineNumber(pyFrame);
MlirLocation loc =
mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
#else
// co_qualname and PyCode_Addr2Location added in py3.11
std::string name =
nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
llvm::StringRef funcName(name);
int startLine, startCol, endLine, endCol;
int lasti = PyFrame_GetLasti(pyFrame);
if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
&endCol)) {
throw nb::python_error();
}
MlirLocation loc = mlirLocationFileLineColRangeGet(
ctx, wrap(fileName), startLine, startCol, endLine, endCol);
#endif

frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
++count;
}
// When the loop breaks (after the last iter), current frame (if non-null)
// is leaked without this.
Py_XDECREF(pyFrame);

if (count == 0)
return mlirLocationUnknownGet(ctx);

MlirLocation callee = frames[0];
assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
if (count == 1)
return callee;

MlirLocation caller = frames[count - 1];
assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
for (int i = count - 2; i >= 1; i--)
caller = mlirLocationCallSiteGet(frames[i], caller);

return mlirLocationCallSiteGet(callee, caller);
}

PyLocation
maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
if (location.has_value())
return location.value();
if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
return DefaultingPyLocation::resolve();

PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
MlirLocation mlirLoc = tracebackToLocation(ctx.get());
PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
return {ref, mlirLoc};
}

} // namespace

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -3052,10 +3133,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def("__eq__", [](PyLocation &self, nb::object other) { return false; })
.def_prop_ro_static(
"current",
[](nb::object & /*class*/) {
[](nb::object & /*class*/) -> std::optional<PyLocation *> {
auto *loc = PyThreadContextEntry::getDefaultLocation();
if (!loc)
throw nb::value_error("No current Location");
return std::nullopt;
Comment on lines +3136 to +3139
Copy link
Contributor Author

@makslevental makslevental Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change Location.current to return None instead of throwing - see above

return loc;
},
"Gets the Location bound to the current thread or raises ValueError")
Expand Down Expand Up @@ -3240,8 +3321,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"create",
[](DefaultingPyLocation loc) {
MlirModule module = mlirModuleCreateEmpty(loc);
[](const std::optional<PyLocation> &loc) {
PyLocation pyLoc = maybeGetTracebackLocation(loc);
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
return PyModule::forModule(module).releaseObject();
},
nb::arg("loc").none() = nb::none(), "Creates an empty module")
Expand Down Expand Up @@ -3454,8 +3536,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<std::vector<PyValue *>> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
DefaultingPyLocation location, const nb::object &maybeIp,
bool inferType) {
const std::optional<PyLocation> &location,
const nb::object &maybeIp, bool inferType) {
// Unpack/validate operands.
llvm::SmallVector<MlirValue, 4> mlirOperands;
if (operands) {
Expand All @@ -3467,8 +3549,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
}
}

PyLocation pyLoc = maybeGetTracebackLocation(location);
return PyOperation::create(name, results, mlirOperands, attributes,
successors, regions, location, maybeIp,
successors, regions, pyLoc, maybeIp,
inferType);
},
nb::arg("name"), nb::arg("results").none() = nb::none(),
Expand Down Expand Up @@ -3512,12 +3595,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
std::optional<int> regions,
const std::optional<PyLocation> &location,
const nb::object &maybeIp) {
PyLocation pyLoc = maybeGetTracebackLocation(location);
new (self) PyOpView(PyOpView::buildGeneric(
name, opRegionSpec, operandSegmentSpecObj,
resultSegmentSpecObj, resultTypeList, operandList,
attributes, successors, regions, location, maybeIp));
attributes, successors, regions, pyLoc, maybeIp));
},
nb::arg("name"), nb::arg("opRegionSpec"),
nb::arg("operandSegmentSpecObj").none() = nb::none(),
Expand Down Expand Up @@ -3551,17 +3636,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](nb::handle cls, std::optional<nb::list> resultTypeList,
nb::list operandList, std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
std::optional<int> regions, std::optional<PyLocation> location,
const nb::object &maybeIp) {
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
std::tuple<int, bool> opRegionSpec =
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
PyLocation pyLoc = maybeGetTracebackLocation(location);
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
resultSegmentSpec, resultTypeList,
operandList, attributes, successors,
regions, location, maybeIp);
regions, pyLoc, maybeIp);
},
nb::arg("cls"), nb::arg("results").none() = nb::none(),
nb::arg("operands").none() = nb::none(),
Expand Down
70 changes: 69 additions & 1 deletion mlir/lib/Bindings/Python/IRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

#include "Globals.h"
#include "NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.

namespace nb = nanobind;
using namespace mlir;
Expand Down Expand Up @@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
// Not found and loading did not yield a registration.
return std::nullopt;
}

bool PyGlobals::TracebackLoc::locTracebacksEnabled() {
nanobind::ft_lock_guard lock(mutex);
return locTracebackEnabled_;
}

void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) {
nanobind::ft_lock_guard lock(mutex);
locTracebackEnabled_ = value;
}

size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() {
nanobind::ft_lock_guard lock(mutex);
return locTracebackFramesLimit_;
}

void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) {
nanobind::ft_lock_guard lock(mutex);
locTracebackFramesLimit_ = std::min(value, kMaxFrames);
}

void PyGlobals::TracebackLoc::registerTracebackFileInclusion(
const std::string &file) {
nanobind::ft_lock_guard lock(mutex);
auto reg = "^" + llvm::Regex::escape(file);
if (userTracebackIncludeFiles.insert(reg).second)
rebuildUserTracebackIncludeRegex = true;
if (userTracebackExcludeFiles.count(reg)) {
if (userTracebackExcludeFiles.erase(reg))
rebuildUserTracebackExcludeRegex = true;
}
}

void PyGlobals::TracebackLoc::registerTracebackFileExclusion(
const std::string &file) {
nanobind::ft_lock_guard lock(mutex);
auto reg = "^" + llvm::Regex::escape(file);
if (userTracebackExcludeFiles.insert(reg).second)
rebuildUserTracebackExcludeRegex = true;
if (userTracebackIncludeFiles.count(reg)) {
if (userTracebackIncludeFiles.erase(reg))
rebuildUserTracebackIncludeRegex = true;
}
}

bool PyGlobals::TracebackLoc::isUserTracebackFilename(
const llvm::StringRef file) {
nanobind::ft_lock_guard lock(mutex);
if (rebuildUserTracebackIncludeRegex) {
userTracebackIncludeRegex.assign(
llvm::join(userTracebackIncludeFiles, "|"));
rebuildUserTracebackIncludeRegex = false;
isUserTracebackFilenameCache.clear();
}
if (rebuildUserTracebackExcludeRegex) {
userTracebackExcludeRegex.assign(
llvm::join(userTracebackExcludeFiles, "|"));
rebuildUserTracebackExcludeRegex = false;
isUserTracebackFilenameCache.clear();
}
if (!isUserTracebackFilenameCache.contains(file)) {
std::string fileStr = file.str();
bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
isUserTracebackFilenameCache[file] = include || !exclude;
}
return isUserTracebackFilenameCache[file];
}
Loading