diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5feed95f96f53..2975e7add7d49 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -9,6 +9,7 @@
 #include "Globals.h"
 #include "IRModule.h"
 #include "NanobindUtils.h"
+#include "Traceback.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Debug.h"
@@ -1523,7 +1524,7 @@ nb::object PyOperation::create(std::string_view name,
                                llvm::ArrayRef operands,
                                std::optional attributes,
                                std::optional> successors,
-                               int regions, DefaultingPyLocation location,
+                               int regions, PyLocation location,
                                const nb::object &maybeIp, bool inferType) {
   llvm::SmallVector mlirResults;
   llvm::SmallVector mlirSuccessors;
@@ -1627,7 +1628,7 @@ nb::object PyOperation::create(std::string_view name,
   if (!operation.ptr)
     throw nb::value_error("Operation creation failed");
   PyOperationRef created =
-      PyOperation::createDetached(location->getContext(), operation);
+      PyOperation::createDetached(location.getContext(), operation);
   maybeInsertOperation(created, maybeIp);
 
   return created.getObject();
@@ -1937,9 +1938,9 @@ nb::object PyOpView::buildGeneric(
     std::optional resultTypeList, nb::list operandList,
     std::optional attributes,
     std::optional> successors,
-    std::optional regions, DefaultingPyLocation location,
+    std::optional regions, PyLocation location,
     const nb::object &maybeIp) {
-  PyMlirContextRef context = location->getContext();
+  PyMlirContextRef context = location.getContext();
 
   // Class level operation construction metadata.
   // Operand and result segment specs are either none, which does no
@@ -3456,6 +3457,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              std::optional> successors, int regions,
              DefaultingPyLocation location, const nb::object &maybeIp,
              bool inferType) {
+            // Locate the op at the Python call stack when traceback collection
+            // is enabled; Get() yields nullopt otherwise, so keep `location`.
+            std::optional<Traceback> tb = Traceback::Get();
+            PyMlirContextRef ctx = location->getContext();
+            PyLocation pyLoc = !tb ? *location.get() : PyLocation{
+                ctx, tb->tracebackToLocation(ctx->get())};
+
             // Unpack/validate operands.
             llvm::SmallVector mlirOperands;
             if (operands) {
@@ -3468,7 +3476,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             }
 
             return PyOperation::create(name, results, mlirOperands, attributes,
-                                       successors, regions, location, maybeIp,
+                                       successors, regions, pyLoc, maybeIp,
                                        inferType);
           },
           nb::arg("name"), nb::arg("results").none() = nb::none(),
@@ -3517,7 +3525,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
               new (self) PyOpView(PyOpView::buildGeneric(
                   name, opRegionSpec, operandSegmentSpecObj,
                   resultSegmentSpecObj, resultTypeList, operandList,
-                  attributes, successors, regions, location, maybeIp));
+                  attributes, successors, regions, *location.get(), maybeIp));
             },
             nb::arg("name"), nb::arg("opRegionSpec"),
             nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3553,6 +3561,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              std::optional> successors,
              std::optional regions, DefaultingPyLocation location,
              const nb::object &maybeIp) {
+            // Locate the op at the Python call stack when traceback collection
+            // is enabled; Get() yields nullopt otherwise, so keep `location`.
+            std::optional<Traceback> tb = Traceback::Get();
+            PyMlirContextRef ctx = location->getContext();
+            PyLocation pyLoc = !tb ? *location.get() : PyLocation{
+                ctx, tb->tracebackToLocation(ctx->get())};
             std::string name = nb::cast(cls.attr("OPERATION_NAME"));
             std::tuple opRegionSpec =
                 nb::cast>(cls.attr("_ODS_REGIONS"));
@@ -3561,7 +3575,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
                                           resultSegmentSpec, resultTypeList,
                                           operandList, attributes, successors,
-                                          regions, location, maybeIp);
+                                          regions, pyLoc, maybeIp);
           },
           nb::arg("cls"), nb::arg("results").none() = nb::none(),
           nb::arg("operands").none() = nb::none(),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9c22dea157c06..e21d8660e8434 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -722,7 +722,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
          llvm::ArrayRef operands,
          std::optional attributes,
          std::optional> successors, int regions,
-         DefaultingPyLocation location, const nanobind::object &ip,
+         PyLocation location, const nanobind::object &ip,
          bool inferType);
 
   /// Creates an OpView suitable for this operation.
@@ -781,7 +781,7 @@ class PyOpView : public PyOperationBase {
                nanobind::list operandList,
                std::optional attributes,
                std::optional> successors,
-               std::optional regions, DefaultingPyLocation location,
+               std::optional regions, PyLocation location,
                const nanobind::object &maybeIp);
 
   /// Construct an instance of a class deriving from OpView, bypassing its
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 6f49431006605..489d8e21a56cd 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,7 @@
 #include "NanobindUtils.h"
 #include "Pass.h"
 #include "Rewrite.h"
+#include "Traceback.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 
 namespace nb = nanobind;
@@ -105,6 +106,7 @@ NB_MODULE(_mlir, m) {
       "typeid"_a, nb::kw_only(), "replace"_a = false,
       "Register a value caster for casting MLIR values to custom user values.");
 
+  BuildTracebackSubmodule(m);
   // Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); populateIRCore(irModule); diff --git a/mlir/lib/Bindings/Python/Traceback.cpp b/mlir/lib/Bindings/Python/Traceback.cpp new file mode 100644 index 0000000000000..812f30534bac9 --- /dev/null +++ b/mlir/lib/Bindings/Python/Traceback.cpp @@ -0,0 +1,535 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "Traceback.h" +#include "IRModule.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep + +#include "llvm/ADT/StringExtras.h" + +#ifdef PLATFORM_GOOGLE +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#undef Py_BUILD_CORE +#endif // PLATFORM_GOOGLE + +// Introduced in python 3.10 +#if PY_VERSION_HEX < 0x030a00f0 +PyObject *Py_NewRef(PyObject *o) { + Py_INCREF(o); + return o; +} +#endif + +// bpo-40421 added PyFrame_GetLasti() to Python 3.11.0b1 +#if PY_VERSION_HEX < 0x030B00B1 && !defined(PYPY_VERSION) +static inline int PyFrame_GetLasti(PyFrameObject *frame) { +#if PY_VERSION_HEX >= 0x030A00A7 + // bpo-27129: Since Python 3.10.0a7, f_lasti is an instruction offset, + // not a bytes offset 
anymore. Python uses 16-bit "wordcode" (2 bytes) + // instructions. + if (frame->f_lasti < 0) { + return -1; + } + return frame->f_lasti * 2; +#else + return frame->f_lasti; +#endif +} +#endif + +namespace mlir::python { +struct TracebackEntry; +struct TracebackObject; +} // namespace mlir::python +namespace nb = nanobind; + +template <> +struct std::hash { + std::size_t + operator()(const mlir::python::TracebackObject &tb) const noexcept; +}; + +template <> +struct std::hash { + std::size_t + operator()(const mlir::python::TracebackEntry &tbe) const noexcept; +}; + +namespace mlir::python { + +std::atomic traceback_enabled_ = true; + +static constexpr int kMaxFrames = 512; + +PyTypeObject *traceback_type_ = nullptr; + +// Entry in a traceback. Must be POD. +struct TracebackEntry { + TracebackEntry() = default; + TracebackEntry(PyCodeObject *code, int lasti) : code(code), lasti(lasti) {} + PyCodeObject *code; + int lasti; + + bool operator==(const TracebackEntry &other) const { + return code == other.code && lasti == other.lasti; + } + bool operator!=(const TracebackEntry &other) const { + return !operator==(other); + } +}; +static_assert(std::is_trivial_v == true); + +struct TracebackObject { + PyObject_VAR_HEAD; + TracebackEntry frames[]; +}; + +static_assert(sizeof(TracebackObject) % alignof(PyObject) == 0); +static_assert(sizeof(TracebackEntry) % alignof(void *) == 0); + +bool traceback_check(nb::handle o) { + return Py_TYPE(o.ptr()) == traceback_type_; +} + +Py_hash_t traceback_tp_hash(PyObject *o) { + TracebackObject *tb = reinterpret_cast(o); + std::hash hasher{}; + size_t h = hasher(*tb); + Py_hash_t s = llvm::bit_cast(h); // Python hashes are signed. + return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. 
+}
+
+// tp_richcompare slot: tracebacks are equal iff their frames match pairwise.
+PyObject *traceback_tp_richcompare(PyObject *self, PyObject *other, int op) {
+  // Only == and != are defined; other operators return NotImplemented.
+  if (op != Py_EQ && op != Py_NE)
+    return Py_NewRef(Py_NotImplemented);
+  // Any mismatch answers "not equal": False for ==, but True for !=.
+  PyObject *unequal = op == Py_EQ ? Py_False : Py_True;
+  // A non-Traceback operand is never equal, so != must yield True here too.
+  if (!traceback_check(other))
+    return Py_NewRef(unequal);
+  TracebackObject *tb_self = reinterpret_cast<TracebackObject *>(self);
+  TracebackObject *tb_other = reinterpret_cast<TracebackObject *>(other);
+  if (Py_SIZE(tb_self) != Py_SIZE(tb_other))
+    return Py_NewRef(unequal);
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb_self); ++i) {
+    if (tb_self->frames[i] != tb_other->frames[i])
+      return Py_NewRef(unequal);
+  }
+  return Py_NewRef(op == Py_EQ ? Py_True : Py_False);
+}
+
+// tp_dealloc slot: releases the code-object reference held by each frame.
+static void traceback_tp_dealloc(PyObject *self) {
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(self);
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i)
+    Py_XDECREF(tb->frames[i].code);
+  PyTypeObject *tp = Py_TYPE(self);
+  tp->tp_free(self);
+  Py_DECREF(tp); // Read tp before freeing self; the type ref is dropped last.
+}
+
+Traceback::Frame DecodeFrame(const TracebackEntry &frame) {
+  // Builds a printable Frame; co_qualname is only used from Python 3.11 on.
+#if PY_VERSION_HEX < 0x030b00f0
+  PyObject *name = frame.code->co_name;
+#else
+  PyObject *name = frame.code->co_qualname;
+#endif
+  return Traceback::Frame{
+      /*file_name=*/nb::borrow<nb::str>(frame.code->co_filename),
+      /*function_name=*/nb::borrow<nb::str>(name),
+      /*function_start_line=*/frame.code->co_firstlineno,
+      /*line_num=*/PyCode_Addr2Line(frame.code, frame.lasti),
+  };
+}
+
+// Renders one decoded frame per line, in the order stored in the traceback.
+std::string traceback_to_string(const TracebackObject *tb) {
+  std::vector<std::string> frame_strs;
+  frame_strs.reserve(Py_SIZE(tb));
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i)
+    frame_strs.push_back(DecodeFrame(tb->frames[i]).ToString());
+  return llvm::join(frame_strs, "\n");
+}
+
+PyObject *traceback_tp_str(PyObject *self) { // tp_str slot.
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(self);
+  return nb::cast(traceback_to_string(tb)).release().ptr();
+}
+
+// It turns out to be slightly faster to define a tp_hash slot rather than
+// defining __hash__ and __eq__ on the class.
+PyType_Slot traceback_slots_[] = { + {Py_tp_hash, reinterpret_cast(traceback_tp_hash)}, + {Py_tp_richcompare, reinterpret_cast(traceback_tp_richcompare)}, + {Py_tp_dealloc, reinterpret_cast(traceback_tp_dealloc)}, + {Py_tp_str, reinterpret_cast(traceback_tp_str)}, + {0, nullptr}, +}; + +nb::object AsPythonTraceback(const Traceback &tb) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type(reinterpret_cast(&PyTraceBack_Type)); + TracebackObject *tb_obj = reinterpret_cast(tb.ptr()); + for (Py_ssize_t i = 0; i < Py_SIZE(tb_obj); ++i) { + const TracebackEntry &frame = tb_obj->frames[i]; + int lineno = PyCode_Addr2Line(frame.code, frame.lasti); + // Under Python 3.11 we observed crashes when using a fake PyFrameObject + // with a real PyCodeObject (https://github.com/google/jax/issues/16027). + // because the frame does not have fields necessary to compute the locals, + // notably the closure object, leading to crashes in CPython in + // _PyFrame_FastToLocalsWithError + // https://github.com/python/cpython/blob/deaf509e8fc6e0363bd6f26d52ad42f976ec42f2/Objects/frameobject.c#LL1116C2-L1116C2 + // We therefore always build a fake code object to go along with our fake + // frame. + PyCodeObject *py_code = + PyCode_NewEmpty(PyUnicode_AsUTF8(frame.code->co_filename), + PyUnicode_AsUTF8(frame.code->co_name), lineno); + PyFrameObject *py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/nullptr); + Py_DECREF(py_code); + + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/lineno); + } + return traceback; +} + +std::vector Traceback::Frames() const { + // We require the GIL because we manipulate Python strings. 
+  assert(PyGILState_Check());
+  std::vector<Frame> frames;
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(ptr());
+  frames.reserve(Py_SIZE(tb));
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) {
+    const TracebackEntry &frame = tb->frames[i];
+    frames.push_back(DecodeFrame(frame));
+  }
+  return frames;
+}
+
+// Converts this traceback into an MLIR location. Each frame becomes a name
+// location (the function name) wrapping a file location; several frames are
+// nested as call sites, with the first stored frame as the outermost callee.
+// Returns an unknown location when the traceback holds no frames.
+MlirLocation Traceback::tracebackToLocation(MlirContext ctx) const {
+  // We require the GIL because we manipulate Python strings.
+  assert(PyGILState_Check());
+
+  // Upper bound on the number of frames converted, counted from the front.
+  constexpr size_t kFramesLimit = 100;
+  std::vector<MlirLocation> frame_locs{};
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(ptr());
+  frame_locs.reserve(Py_SIZE(tb));
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) {
+    const TracebackEntry &frame = tb->frames[i];
+    // TODO: skip non-user files (if not _is_user_file(code.co_filename):
+    // continue) and canonicalize the file name
+    // (get_canonical_source_file).
+    MlirStringRef fileName = mlirStringRefCreateFromCString(
+        nb::borrow<nb::str>(frame.code->co_filename).c_str());
+#if PY_VERSION_HEX < 0x030b00f0
+    MlirStringRef funcName = mlirStringRefCreateFromCString(
+        nb::borrow<nb::str>(frame.code->co_name).c_str());
+    auto line = PyCode_Addr2Line(frame.code, frame.lasti);
+    auto loc = mlirLocationFileLineColGet(ctx, fileName, line, 0);
+#else
+    MlirStringRef funcName = mlirStringRefCreateFromCString(
+        nb::borrow<nb::str>(frame.code->co_qualname).c_str());
+    int start_line, start_column, end_line, end_column;
+    // NOTE(review): confirm a 0 return always leaves a Python error set.
+    if (!PyCode_Addr2Location(frame.code, frame.lasti, &start_line,
+                              &start_column, &end_line, &end_column))
+      throw nb::python_error();
+    // Argument order is (start line, start column, end line, end column).
+    auto loc = mlirLocationFileLineColRangeGet(
+        ctx, fileName, start_line, start_column, end_line, end_column);
+#endif
+    frame_locs.push_back(mlirLocationNameGet(ctx, funcName, loc));
+    // Stop as soon as kFramesLimit frames have been converted.
+    if (frame_locs.size() >= kFramesLimit)
+      break;
+  }
+
+  if (frame_locs.empty())
+    return mlirLocationUnknownGet(ctx);
+  if (frame_locs.size() == 1)
+    return frame_locs.front();
+
+  // Fold from the last frame towards the first so the result nests as
+  // callsite(frame[0] at callsite(frame[1] at ... frame[n-1])).
+  MlirLocation caller = frame_locs.back();
+  for (const MlirLocation &frame :
+       llvm::reverse(llvm::ArrayRef<MlirLocation>(frame_locs).drop_back()))
+    caller = mlirLocationCallSiteGet(frame, caller);
+
+  return caller;
+}
+
+// Formats the frame as "<file>:<line> (<function>)".
+std::string Traceback::Frame::ToString() const {
+  std::string s = nb::cast<std::string>(file_name);
+  s += ":" + std::to_string(line_num) + " ";
+  s += "(" + nb::cast<std::string>(function_name) + ")";
+  return s;
+}
+
+std::string Traceback::ToString() const {
+  return traceback_to_string(reinterpret_cast<const TracebackObject *>(ptr()));
+}
+
+// Returns the stored (code object, lasti) pairs without decoding them.
+std::vector<std::pair<PyCodeObject *, int>> Traceback::RawFrames() const {
+  const TracebackObject *tb = reinterpret_cast<const TracebackObject *>(ptr());
+  std::vector<std::pair<PyCodeObject *, int>> frames;
+  frames.reserve(Py_SIZE(tb));
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i)
+    frames.push_back(std::make_pair(tb->frames[i].code, tb->frames[i].lasti));
+  return frames;
+}
+
+/*static*/ bool Traceback::Check(PyObject *o) { return traceback_check(o); }
+
+// Captures the calling thread's Python stack, starting at the current frame.
+/*static*/ std::optional<Traceback> Traceback::Get() {
+  // We use a thread_local here mostly to avoid requiring a large amount of
+  // space.
+  thread_local std::array<TracebackEntry, kMaxFrames> frames;
+  int count = 0;
+
+  assert(PyGILState_Check());
+
+  if (!traceback_enabled_.load())
+    return std::nullopt;
+
+  PyThreadState *thread_state = PyThreadState_GET();
+
+#ifdef PLATFORM_GOOGLE
+// This code is equivalent to the version using public APIs, but it saves us
+// an allocation of one object per stack frame. However, this is definitely
+// violating the API contract of CPython, so we only use this where we can be
+// confident we know exactly which CPython we are using (internal to Google).
+// Feel free to turn this on if you like, but it might break at any time!
+#if PY_VERSION_HEX < 0x030d0000
+  for (_PyInterpreterFrame *f = thread_state->cframe->current_frame;
+       f != nullptr && count < kMaxFrames; f = f->previous) {
+    if (_PyFrame_IsIncomplete(f))
+      continue;
+    Py_INCREF(f->f_code);
+    frames[count] = {f->f_code, static_cast<int>(_PyInterpreterFrame_LASTI(f) *
+                                                 sizeof(_Py_CODEUNIT))};
+    ++count;
+  }
+#else // PY_VERSION_HEX < 0x030d0000
+  for (_PyInterpreterFrame *f = thread_state->current_frame;
+       f != nullptr && count < kMaxFrames; f = f->previous) {
+    if (_PyFrame_IsIncomplete(f))
+      continue;
+    Py_INCREF(f->f_executable);
+    frames[count] = {
+        reinterpret_cast<PyCodeObject *>(f->f_executable),
+        static_cast<int>(_PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT))};
+    ++count;
+  }
+#endif // PY_VERSION_HEX < 0x030d0000
+
+#else // PLATFORM_GOOGLE
+  // Every frame visited is a strong reference owned by this loop; the code
+  // reference returned by PyFrame_GetCode is handed over to `frames`.
+  PyFrameObject *py_frame = PyThreadState_GetFrame(thread_state);
+  while (py_frame != nullptr && count < kMaxFrames) {
+    frames[count] = {PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)};
+    ++count;
+    PyFrameObject *next = PyFrame_GetBack(py_frame);
+    Py_DECREF(py_frame);
+    py_frame = next;
+  }
+  // Still non-null only when the walk stopped at kMaxFrames: release it.
+  Py_XDECREF(py_frame);
+#endif // PLATFORM_GOOGLE
+
+  PyObject *raw = PyObject_NewVar(PyObject, traceback_type_, count);
+  if (!raw) {
+    // Allocation failed: drop the code references collected above before
+    // propagating the Python error.
+    for (int i = 0; i < count; ++i)
+      Py_DECREF(frames[i].code);
+    throw nb::python_error();
+  }
+  Traceback traceback = nb::steal<Traceback>(raw);
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(traceback.ptr());
+  // Ownership of the code references moves into the new object.
+  std::memcpy(tb->frames, frames.data(), sizeof(TracebackEntry) * count);
+  return traceback;
+}
+
+// Wraps `get` in a Python `property` that has a getter only.
+template <typename Func>
+nanobind::object nb_property_readonly(Func &&get) {
+  nanobind::handle property(reinterpret_cast<PyObject *>(&PyProperty_Type));
+  return property(nanobind::cpp_function(std::forward<Func>(get)),
+                  nanobind::none(), nanobind::none(), "");
+}
+
+// Registers the Frame class, the Traceback type and the
+// tracebacks_enabled / set_tracebacks_enabled functions on module `m`.
+void BuildTracebackSubmodule(nb::module_ &m) {
+  nb::class_<Traceback::Frame>(m, "Frame")
+      .def(nb::init())
+      .def_ro("file_name", &Traceback::Frame::file_name)
+      .def_ro("function_name", &Traceback::Frame::function_name)
+      .def_ro("function_start_line", &Traceback::Frame::function_start_line)
+      .def_ro("line_num", &Traceback::Frame::line_num)
+      .def("__repr__", [](const Traceback::Frame &frame) {
+        std::string s = nb::cast<std::string>(frame.function_name);
+        s += ";" + nb::cast<std::string>(frame.file_name);
+        s += ":" + std::to_string(frame.line_num);
+        return s;
+      });
+
+  // The type is published as "<module name>.Traceback".
+  std::string name = nb::cast<std::string>(m.attr("__name__"));
+  name += ".Traceback";
+
+  PyType_Spec traceback_spec = {
+      /*.name=*/name.c_str(),
+      /*.basicsize=*/static_cast<int>(sizeof(TracebackObject)),
+      /*.itemsize=*/static_cast<int>(sizeof(TracebackEntry)),
+      /*.flags=*/Py_TPFLAGS_DEFAULT,
+      /*.slots=*/traceback_slots_,
+  };
+
+  traceback_type_ =
+      reinterpret_cast<PyTypeObject *>(PyType_FromSpec(&traceback_spec));
+  if (!traceback_type_) {
+    throw nb::python_error();
+  }
+
+  auto type = nb::borrow(traceback_type_);
+  m.attr("Traceback") = type;
+
+  m.def("tracebacks_enabled", []() { return traceback_enabled_.load(); });
+  m.def("set_tracebacks_enabled",
+        [](bool value) { traceback_enabled_.store(value); });
+
+  type.attr("get_traceback") = nb::cpp_function(Traceback::Get,
+                                                R"doc(
+    Returns a :class:`Traceback` for the current thread.
+
+    If traceback collection is enabled (see ``tracebacks_enabled``), returns
+    a :class:`Traceback` object that describes the Python stack of the
+    calling thread. Collection is enabled by default; after
+    ``set_tracebacks_enabled(False)`` this returns ``None``. )doc");
+  type.attr("_infer_location") = nb::cpp_function(
+      [](DefaultingPyMlirContext context) {
+        auto tb = Traceback::Get();
+        // Get() yields nullopt when collection is disabled; report an
+        // unknown location instead of dereferencing an empty optional.
+        if (!tb)
+          return mlirLocationUnknownGet(context->get());
+        return tb->tracebackToLocation(context->get());
+      },
+      nb::arg("context") = nb::none());
+  type.attr("frames") = nb_property_readonly(&Traceback::Frames);
+  type.attr("raw_frames") = nb::cpp_function(
+      [](const Traceback &tb) -> nb::tuple {
+        // We return a tuple of lists, rather than a list of tuples, because it
+        // is cheaper to allocate only three Python objects for everything
+        // rather than one per frame.
+ std::vector> frames = tb.RawFrames(); + nb::list out_code = nb::steal(PyList_New(frames.size())); + nb::list out_lasti = nb::steal(PyList_New(frames.size())); + for (size_t i = 0; i < frames.size(); ++i) { + const auto &frame = frames[i]; + PyObject *code = reinterpret_cast(frame.first); + Py_INCREF(code); + PyList_SET_ITEM(out_code.ptr(), i, code); + PyList_SET_ITEM(out_lasti.ptr(), i, + nb::int_(frame.second).release().ptr()); + } + return nb::make_tuple(out_code, out_lasti); + }, + nb::is_method()); + type.attr("as_python_traceback") = + nb::cpp_function(AsPythonTraceback, nb::is_method()); + + type.attr("traceback_from_frames") = nb::cpp_function( + [](std::vector frames) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type( + reinterpret_cast(&PyTraceBack_Type)); + for (const Traceback::Frame &frame : frames) { + PyCodeObject *py_code = + PyCode_NewEmpty(frame.file_name.c_str(), + frame.function_name.c_str(), frame.line_num); + PyFrameObject *py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/ + nullptr); + Py_DECREF(py_code); + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + frame.line_num); + } + return traceback; + }, + "Creates a traceback from a list of frames."); + + type.attr("code_addr2line") = nb::cpp_function( + [](nb::handle code, int lasti) { + if (!PyCode_Check(code.ptr())) { + throw std::runtime_error("code argument must be a code object"); + } + return PyCode_Addr2Line(reinterpret_cast(code.ptr()), + lasti); + }, + "Python wrapper around the Python C API function PyCode_Addr2Line"); + +#if PY_VERSION_HEX >= 0x030b00f0 + type.attr("code_addr2location") = nb::cpp_function( + [](nb::handle code, int lasti) { + if (!PyCode_Check(code.ptr())) { + throw std::runtime_error("code argument must be a code object"); + } + int start_line, start_column, end_line, end_column; + if 
(!PyCode_Addr2Location(reinterpret_cast(code.ptr()), + lasti, &start_line, &start_column, &end_line, + &end_column)) { + throw nb::python_error(); + } + return nb::make_tuple(start_line, start_column, end_line, end_column); + }, + "Python wrapper around the Python C API function PyCode_Addr2Location"); +#endif +} +} // namespace mlir::python + +std::size_t std::hash::operator()( + const mlir::python::TracebackObject &tb) const noexcept { + const unsigned length = Py_SIZE(&tb); + const mlir::python::TracebackEntry *begin = &tb.frames[0]; + const mlir::python::TracebackEntry *end = begin + length; + const unsigned *VBegin = reinterpret_cast(begin); + const unsigned *VEnd = reinterpret_cast(end); + return llvm::hash_combine(length, llvm::hash_combine_range(VBegin, VEnd)); +} + +std::size_t std::hash::operator()( + const mlir::python::TracebackEntry &tbe) const noexcept { + return llvm::hash_combine(tbe.code, tbe.lasti); +} \ No newline at end of file diff --git a/mlir/lib/Bindings/Python/Traceback.h b/mlir/lib/Bindings/Python/Traceback.h new file mode 100644 index 0000000000000..0dcfedd8b974b --- /dev/null +++ b/mlir/lib/Bindings/Python/Traceback.h @@ -0,0 +1,69 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef JAXLIB_TRACEBACK_H_
+#define JAXLIB_TRACEBACK_H_
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include "mlir-c/IR.h"
+
+// placeholder for index annotation headers
+#include "nanobind/nanobind.h"
+
+namespace mlir::python {
+
+// nanobind object wrapper around a captured Python call stack.
+class Traceback : public nanobind::object {
+public:
+  NB_OBJECT(Traceback, nanobind::object, "Traceback", Traceback::Check);
+
+  // Returns a traceback if it is enabled, otherwise returns nullopt.
+  static std::optional Get();
+
+  // Returns a string representation of the traceback.
+  std::string ToString() const;
+
+  // Returns a list of (code, lasti) pairs for each frame in the traceback.
+  std::vector> RawFrames() const;
+
+  // One decoded stack frame.
+  struct Frame {
+    nanobind::str file_name;
+    nanobind::str function_name;
+    int function_start_line;
+    int line_num;
+
+    // Returns a string representation of the frame.
+    std::string ToString() const;
+  };
+
+  // Returns a list of Frames for the traceback.
+  std::vector Frames() const;
+
+  // Converts the traceback into an MLIR location created in `ctx`.
+  MlirLocation tracebackToLocation(MlirContext ctx) const;
+
+private:
+  // Type check used by NB_OBJECT above.
+  static bool Check(PyObject *o);
+};
+
+// Registers the traceback bindings on module `m`.
+void BuildTracebackSubmodule(nanobind::module_ &m);
+
+} // namespace mlir::python
+
+#endif // JAXLIB_TRACEBACK_H_
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7a0c95ebb8200..243fbe64de900 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -23,6 +23,8 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     passmanager.py
     rewrite.py
     dialects/_ods_common.py
+    source_info_util.py
+    traceback_util.py
 
     # The main _mlir module has submodules: include stubs from each.
     _mlir_libs/_mlir/__init__.pyi
@@ -486,6 +488,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     IRTypes.cpp
     Pass.cpp
     Rewrite.cpp
+    Traceback.cpp
 
     # Headers must be included explicitly so they are installed.
Globals.h @@ -493,6 +496,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core Pass.h NanobindUtils.h Rewrite.h + Traceback.h PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS diff --git a/mlir/python/mlir/source_info_util.py b/mlir/python/mlir/source_info_util.py new file mode 100644 index 0000000000000..937f139164491 --- /dev/null +++ b/mlir/python/mlir/source_info_util.py @@ -0,0 +1,495 @@ +# Copyright 2020 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sys +from collections.abc import Iterator +import contextlib +import dataclasses +import functools +import itertools +import os.path +import re +import sysconfig +import threading +import types +from typing import NamedTuple, Optional +from .ir import Location + +# import jax.version +# from jax._src.lib import xla_client + +from . import traceback_util +from .traceback_util import ( + TracebackCaches, + Traceback, + _traceback_caches, + _traceback_in_locations_limit, + _include_full_tracebacks_in_locations, +) + +traceback_util.register_exclusion(__file__) + + +@dataclasses.dataclass(frozen=True) +class Frame: + file_name: str + function_name: str + start_line: int + start_column: Optional[int] = None + end_line: Optional[int] = None + end_column: Optional[int] = None + + +_exclude_paths: list[str] = [ + # Attach the separator to make sure that .../jax does not end up matching + # .../jax_triton and other packages that might have a jax prefix. 
+ # os.path.dirname(jax.version.__file__) + os.sep, + # Also exclude stdlib as user frames. In a non-standard Python runtime, + # the following two may be different. + sysconfig.get_path("stdlib"), + os.path.dirname(sysconfig.__file__), +] + + +@functools.cache +def _exclude_path_regex() -> re.Pattern[str]: + # The regex below would not handle an empty set of exclusions correctly. + assert len(_exclude_paths) > 0 + return re.compile("|".join(f"^{re.escape(path)}" for path in _exclude_paths)) + + +def register_exclusion(path: str): + _exclude_paths.append(path) + _exclude_path_regex.cache_clear() + is_user_filename.cache_clear() + + +# Explicit inclusions take priority over exclude paths. +_include_paths: list[str] = [] + + +@functools.cache +def _include_path_regex() -> re.Pattern[str]: + patterns = [f"^{re.escape(path)}" for path in _include_paths] + patterns.append("_test.py$") + return re.compile("|".join(patterns)) + + +def register_inclusion(path: str): + _include_paths.append(path) + _include_path_regex.cache_clear() + is_user_filename.cache_clear() + + +class Scope(NamedTuple): + name: str + + def wrap(self, stack: list[str]): + stack.append(self.name) + + +class Transform(NamedTuple): + name: str + + def wrap(self, stack: list[str]): + if stack: + stack[-1] = f"{self.name}({stack[-1]})" + else: + stack.append(f"{self.name}()") + + +@dataclasses.dataclass(frozen=True) +class NameStack: + stack: tuple[Scope | Transform, ...] 
= () + + def extend(self, name: str) -> NameStack: + return NameStack((*self.stack, Scope(name))) + + def transform(self, transform_name: str) -> NameStack: + return NameStack((*self.stack, Transform(transform_name))) + + def __getitem__(self, idx: slice) -> NameStack: + return NameStack(self.stack[idx]) + + def __len__(self): + return len(self.stack) + + def __add__(self, other: NameStack) -> NameStack: + return NameStack(self.stack + other.stack) + + def __radd__(self, other: NameStack) -> NameStack: + return NameStack(other.stack + self.stack) + + def __str__(self) -> str: + scope: list[str] = [] + for elem in self.stack[::-1]: + elem.wrap(scope) + return "/".join(reversed(scope)) + + +def new_name_stack(name: str = "") -> NameStack: + name_stack = NameStack() + if name: + name_stack = name_stack.extend(name) + return name_stack + + +class SourceInfo: + traceback: Traceback | None + name_stack: NameStack + + # It's slightly faster to use a class with __slots__ than a NamedTuple. + __slots__ = ["traceback", "name_stack"] + + def __init__(self, traceback: Traceback | None, name_stack: NameStack): + self.traceback = traceback + self.name_stack = name_stack + + def replace( + self, *, traceback: Traceback | None = None, name_stack: NameStack | None = None + ) -> SourceInfo: + return SourceInfo( + self.traceback if traceback is None else traceback, + self.name_stack if name_stack is None else name_stack, + ) + + +def new_source_info() -> SourceInfo: + return SourceInfo(None, NameStack()) + + +@functools.cache +def is_user_filename(filename: str) -> bool: + """Heuristic that guesses the identity of the user's code in a stack trace.""" + return ( + _include_path_regex().search(filename) is not None + or _exclude_path_regex().search(filename) is None + ) + + +def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame: + if sys.version_info.minor >= 11: + loc = Traceback.code_addr2location(code, lasti) + start_line, start_column, end_line, end_column = loc + frame 
= Frame( + file_name=code.co_filename, + function_name=code.co_qualname, + start_line=start_line, + start_column=start_column, + end_line=end_line, + end_column=end_column, + ) + else: + start_line = Traceback.code_addr2line(code, lasti) + frame = Frame( + file_name=code.co_filename, + function_name=code.co_name, + start_line=start_line, + start_column=0, + ) + + return frame + + +def user_frames(traceback: Traceback | None) -> Iterator[Frame]: + """Iterator over the user's frames, filtering jax-internal frames.""" + # Guess the user's frame is the innermost frame not in the jax source tree or + # Python stdlib. We don't use traceback_util.path_starts_with because that + # incurs filesystem access, which may be slow; we call this function when + # e.g. adding source provenance annotations to XLA lowerings, so we don't + # want to incur the cost. We consider files that end with _test.py as user + # frames, to allow testing this mechanism from tests. + code, lasti = traceback.raw_frames() if traceback else ([], []) + return ( + raw_frame_to_frame(code[i], lasti[i]) + for i in range(len(code)) + if is_user_filename(code[i].co_filename) + ) + + +@functools.lru_cache(maxsize=64) +def user_frame(traceback: Traceback | None) -> Frame | None: + return next(user_frames(traceback), None) + + +def _summarize_frame(frame: Frame) -> str: + if frame.start_column != 0: + return ( + f"{frame.file_name}:{frame.start_line}:{frame.start_column} " + f"({frame.function_name})" + ) + else: + return f"{frame.file_name}:{frame.start_line} ({frame.function_name})" + + +def summarize(source_info: SourceInfo, num_frames=1) -> str: + frames = itertools.islice(user_frames(source_info.traceback), num_frames) + frame_strs = [_summarize_frame(frame) if frame else "unknown" for frame in frames] + return "\n".join(reversed(frame_strs)) + + +class _SourceInfoContext(threading.local): + context: SourceInfo + + def __init__(self): + super().__init__() + self.context = new_source_info() + + 
# Per-thread holder of the active SourceInfo (see _SourceInfoContext, a
# threading.local subclass whose `context` starts as new_source_info()).
_source_info_context = _SourceInfoContext()


def current() -> SourceInfo:
    """Return the active SourceInfo, capturing a traceback if it has none.

    When the stored context carries no traceback, a copy is returned whose
    traceback is `Traceback.get_traceback()`; the stored context itself is
    left untouched.
    """
    active = _source_info_context.context
    if active.traceback:
        return active
    return active.replace(traceback=Traceback.get_traceback())


class JaxStackTraceBeforeTransformation(Exception):
    """Carrier exception attached as `__cause__` by UserContextManager."""


_message = (
    "The preceding stack trace is the source of the JAX operation that, once "
    "transformed by JAX, triggered the following exception.\n"
    "\n--------------------"
)


def has_user_context(e):
    """Return True if `e` or any exception on its `__cause__` chain is a
    JaxStackTraceBeforeTransformation."""
    node = e
    while node is not None and not isinstance(
        node, JaxStackTraceBeforeTransformation
    ):
        node = node.__cause__
    return node is not None


class UserContextManager:
    """Context manager that installs a traceback/name stack for its body.

    On exit the previous context is restored. If the body raised and a
    traceback was supplied, the exception gets a
    JaxStackTraceBeforeTransformation as its `__cause__`, built from the
    filtered form of that traceback.
    """

    __slots__ = ["traceback", "name_stack", "prev"]

    def __init__(
        self, traceback: Traceback | None, *, name_stack: NameStack | None = None
    ):
        self.traceback = traceback
        self.name_stack = name_stack

    def __enter__(self):
        outer = _source_info_context.context
        self.prev = outer
        # SourceInfo.replace keeps the existing field when handed None.
        _source_info_context.context = outer.replace(
            traceback=self.traceback, name_stack=self.name_stack
        )

    def __exit__(self, exc_type, exc_value, traceback):
        _source_info_context.context = self.prev

        # Nothing to annotate unless an exception is propagating, we have a
        # traceback to show, and the chain is not already annotated.
        no_exception = exc_type is None or exc_value is None
        if no_exception:
            return
        if self.traceback is None or has_user_context(exc_value):
            return

        python_tb = self.traceback.as_python_traceback()
        filtered_tb = traceback_util.filter_traceback(python_tb)
        if not filtered_tb:
            return

        summary = traceback_util.format_exception_only(exc_value)
        carrier = JaxStackTraceBeforeTransformation(f"{summary}\n\n{_message}")
        carrier = carrier.with_traceback(filtered_tb)
        # Splice the carrier between exc_value and its former chain. The
        # __cause__ assignment must precede __suppress_context__, because
        # assigning __cause__ implicitly sets __suppress_context__ to True.
        carrier.__context__ = exc_value.__context__
        carrier.__cause__ = exc_value.__cause__
        carrier.__suppress_context__ = exc_value.__suppress_context__
        exc_value.__context__ = None
        exc_value.__cause__ = carrier


user_context = UserContextManager


def current_name_stack() -> NameStack:
    """Return the name stack of the currently stored SourceInfo."""
    active = _source_info_context.context
    return active.name_stack


class ExtendNameStackContextManager(contextlib.ContextDecorator):
    """Push a named scope onto the current name stack for the body's duration.

    `__enter__` returns the extended NameStack.
    """

    __slots__ = ["name", "prev"]

    def __init__(self, name: str):
        self.name = name

    def __enter__(self):
        outer = _source_info_context.context
        self.prev = outer
        extended = outer.name_stack.extend(self.name)
        _source_info_context.context = outer.replace(name_stack=extended)
        return extended

    def __exit__(self, exc_type, exc_value, traceback):
        _source_info_context.context = self.prev


extend_name_stack = ExtendNameStackContextManager


class SetNameStackContextManager(contextlib.ContextDecorator):
    """Replace the current name stack with a given one for the body's duration."""

    __slots__ = ["name_stack", "prev"]

    def __init__(self, name_stack: NameStack):
        self.name_stack = name_stack

    def __enter__(self):
        outer = _source_info_context.context
        self.prev = outer
        _source_info_context.context = outer.replace(name_stack=self.name_stack)

    def __exit__(self, exc_type, exc_value, traceback):
        _source_info_context.context = self.prev


set_name_stack = SetNameStackContextManager


# TODO(mattjj,phawkins): figure out why the commented-out reset_name_stack
# implementation doesn't work. Luckily this context manager isn't called much so
# the performance shouldn't matter. See blame commit message for repro.
+# reset_name_stack = lambda: SetNameStackContextManager(NameStack()) +@contextlib.contextmanager +def reset_name_stack() -> Iterator[None]: + with set_name_stack(NameStack()): + yield + + +class TransformNameStackContextManager(contextlib.ContextDecorator): + __slots__ = ["name", "prev"] + + def __init__(self, name: str): + self.name = name + + def __enter__(self): + self.prev = prev = _source_info_context.context + name_stack = prev.name_stack.transform(self.name) + _source_info_context.context = prev.replace(name_stack=name_stack) + return name_stack + + def __exit__(self, exc_type, exc_value, traceback): + _source_info_context.context = self.prev + + +transform_name_stack = TransformNameStackContextManager + + +def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str: + canonical_file_name = caches.canonical_name_cache.get(file_name, None) + if canonical_file_name is not None: + return canonical_file_name + + # pattern = config.hlo_source_file_canonicalization_regex.value + # if pattern: + # file_name = re.sub(pattern, "", file_name) + caches.canonical_name_cache[file_name] = file_name + return file_name + + +def _is_user_file(file_name: str) -> bool: + is_user = _traceback_caches.is_user_file_cache.get(file_name, None) + if is_user is not None: + return is_user + out = is_user_filename(file_name) + _traceback_caches.is_user_file_cache[file_name] = out + return out + + +# def _traceback_to_location(tb: Traceback) -> Location: +# """Converts a full traceback to a callsite() MLIR location.""" +# loc = _traceback_caches.traceback_cache.get(tb, None) +# if loc is not None: +# return loc +# +# frame_locs = [] +# frames_limit = _traceback_in_locations_limit +# frames_limit = frames_limit if frames_limit >= 0 else 1000 +# +# codes, lastis = tb.raw_frames() +# for _, code in enumerate(codes): +# if not _is_user_file(code.co_filename): +# continue +# +# lasti = lastis[i] +# code_lasti = code, lasti +# loc = 
_traceback_caches.location_cache.get(code_lasti, None) +# if loc is None: +# frame = raw_frame_to_frame(code, lasti) +# if ( +# frame.start_column is not None +# and frame.end_line is not None +# and frame.end_column is not None +# ): +# file_loc = Location.file( +# get_canonical_source_file(frame.file_name, _traceback_caches), +# frame.start_line, +# frame.start_column, +# frame.end_line, +# frame.end_column, +# ) +# else: +# file_loc = Location.file( +# get_canonical_source_file(frame.file_name, _traceback_caches), +# frame.start_line, +# frame.start_column, +# ) +# loc = Location.name(frame.function_name, childLoc=file_loc) +# _traceback_caches.location_cache[code_lasti] = loc +# frame_locs.append(loc) +# if len(frame_locs) >= frames_limit: +# break +# +# n = len(frame_locs) +# if n == 0: +# loc = Location.unknown() +# elif n == 1: +# loc = frame_locs[0] +# else: +# loc = Location.callsite(frame_locs[0], frame_locs[1:]) +# _traceback_caches.traceback_cache[tb] = loc +# return loc + + +def source_info_to_location( + primitive: None, + name_stack: NameStack, + traceback: Traceback | None, +) -> Location: + if _include_full_tracebacks_in_locations: + if traceback is None: + loc = Location.unknown() + else: + loc = _traceback_to_location(traceback) + else: + frame = user_frame(traceback) + if frame is None: + loc = Location.unknown() + else: + loc = Location.file( + get_canonical_source_file(frame.file_name, _traceback_caches), + frame.start_line, + frame.start_column, + ) + if primitive is None: + if name_stack.stack: + loc = Location.name(str(name_stack), childLoc=loc) + else: + eqn_str = ( + f"{name_stack}/{primitive.name}" if name_stack.stack else primitive.name + ) + loc = Location.name(eqn_str, childLoc=loc) + loc = Location.name(f"{primitive.name}:", childLoc=loc) + return loc diff --git a/mlir/python/mlir/traceback_util.py b/mlir/python/mlir/traceback_util.py new file mode 100644 index 0000000000000..893fb43588dfb --- /dev/null +++ 
b/mlir/python/mlir/traceback_util.py @@ -0,0 +1,262 @@ +# Copyright 2020 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses +from collections.abc import Callable +import functools +import os +import traceback +import types +from typing import Any, TypeVar, cast +from ._mlir_libs._mlir import Traceback, set_tracebacks_enabled +from .ir import Location + + +set_tracebacks_enabled(True) + +C = TypeVar("C", bound=Callable[..., Any]) + +_exclude_paths: list[str] = [__file__] + + +def register_exclusion(path: str): + _exclude_paths.append(path) + + +_jax_message_append = ( + "The stack trace below excludes JAX-internal frames.\n" + "The preceding is the original exception that occurred, unmodified.\n" + "\n--------------------" +) + + +def _path_starts_with(path: str, path_prefix: str) -> bool: + path = os.path.abspath(path) + path_prefix = os.path.abspath(path_prefix) + try: + common = os.path.commonpath([path, path_prefix]) + except ValueError: + # path and path_prefix are both absolute, the only case will raise a + # ValueError is different drives. + # https://docs.python.org/3/library/os.path.html#os.path.commonpath + return False + try: + return common == path_prefix or os.path.samefile(common, path_prefix) + except OSError: + # One of the paths may not exist. 
+ return False + + +def include_frame(f: types.FrameType) -> bool: + return include_filename(f.f_code.co_filename) + + +def include_filename(filename: str) -> bool: + return not any(_path_starts_with(filename, path) for path in _exclude_paths) + + +# When scanning stack traces, we might encounter frames from cpython that are +# removed from printed stack traces, such as frames from parts of importlib. We +# ignore these frames heuristically based on source and name match. +def _ignore_known_hidden_frame(f: types.FrameType) -> bool: + return "importlib._bootstrap" in f.f_code.co_filename + + +def _add_tracebackhide_to_hidden_frames(tb: types.TracebackType): + for f, _lineno in traceback.walk_tb(tb): + if not include_frame(f): + f.f_locals["__tracebackhide__"] = True + + +def filter_traceback(tb: types.TracebackType) -> types.TracebackType | None: + out = None + # Scan the traceback and collect relevant frames. + frames = list(traceback.walk_tb(tb)) + for f, lineno in reversed(frames): + if include_frame(f): + out = types.TracebackType(out, f, f.f_lasti, lineno) + return out + + +def _add_call_stack_frames(tb: types.TracebackType) -> types.TracebackType: + # Continue up the call stack. + # + # We would like to avoid stepping too far up, e.g. past the exec/eval point of + # a REPL such as IPython. To that end, we stop past the first contiguous bunch + # of module-level frames, if we reach any such frames at all. This is a + # heuristic that might stop in advance of the REPL boundary. For example, if + # the call stack includes module-level frames from the current module A, and + # the current module A was imported from within a function F elsewhere, then + # the stack trace we produce will be truncated at F's frame. 
+ out = tb + + reached_module_level = False + for f, lineno in traceback.walk_stack(tb.tb_frame): + if _ignore_known_hidden_frame(f): + continue + if reached_module_level and f.f_code.co_name != "": + break + if include_frame(f): + out = types.TracebackType(out, f, f.f_lasti, lineno) + if f.f_code.co_name == "": + reached_module_level = True + return out + + +def _is_reraiser_frame(f: traceback.FrameSummary) -> bool: + return f.filename == __file__ and f.name == "reraise_with_filtered_traceback" + + +def _is_under_reraiser(e: BaseException) -> bool: + if e.__traceback__ is None: + return False + tb = traceback.extract_stack(e.__traceback__.tb_frame) + return any(_is_reraiser_frame(f) for f in tb[:-1]) + + +def format_exception_only(e: BaseException) -> str: + return "".join(traceback.format_exception_only(type(e), e)).strip() + + +class UnfilteredStackTrace(Exception): + pass + + +_simplified_tb_msg = ( + "For simplicity, JAX has removed its internal frames from the " + "traceback of the following exception. Set " + "JAX_TRACEBACK_FILTERING=off to include these." +) + + +class SimplifiedTraceback(Exception): + def __str__(self): + return _simplified_tb_msg + + +SimplifiedTraceback.__module__ = "jax.errors" + + +def _running_under_ipython() -> bool: + """Returns true if we appear to be in an IPython session.""" + try: + get_ipython() # type: ignore + return True + except NameError: + return False + + +def _ipython_supports_tracebackhide() -> bool: + """Returns true if the IPython version supports __tracebackhide__.""" + import IPython # pytype: disable=import-error + + return IPython.version_info[:2] >= (7, 17) + + +def _filtering_mode() -> str: + mode = None + if mode is None or mode == "auto": + if _running_under_ipython() and _ipython_supports_tracebackhide(): + mode = "tracebackhide" + else: + mode = "quiet_remove_frames" + return mode + + +def api_boundary(fun: C) -> C: + """Wraps ``fun`` to form a boundary for filtering exception tracebacks. 
+ + When an exception occurs below ``fun``, this appends to it a custom + ``__cause__`` that carries a filtered traceback. The traceback imitates the + stack trace of the original exception, but with JAX-internal frames removed. + + This boundary annotation works in composition with itself. The topmost frame + corresponding to an :func:`~api_boundary` is the one below which stack traces + are filtered. In other words, if ``api_boundary(f)`` calls + ``api_boundary(g)``, directly or indirectly, the filtered stack trace provided + is the same as if ``api_boundary(f)`` were to simply call ``g`` instead. + + This annotation is primarily useful in wrapping functions output by JAX's + transformations. For example, consider ``g = jax.jit(f)``. When ``g`` is + called, JAX's JIT compilation machinery is invoked, which in turn calls ``f`` + in order to trace and translate it. If the function ``f`` raises an exception, + the stack unwinds through JAX's JIT internals up to the original call site of + ``g``. Because the function returned by :func:`~jax.jit` is annotated as an + :func:`~api_boundary`, such an exception is accompanied by an additional + traceback that excludes the frames specific to JAX's implementation. 
+ """ + + @functools.wraps(fun) + def reraise_with_filtered_traceback(*args, **kwargs): + __tracebackhide__ = True + try: + return fun(*args, **kwargs) + except Exception as e: + mode = _filtering_mode() + if _is_under_reraiser(e) or mode == "off": + raise + if mode == "tracebackhide": + _add_tracebackhide_to_hidden_frames(e.__traceback__) + raise + + filtered_tb, unfiltered = None, None + try: + tb = e.__traceback__ + filtered_tb = filter_traceback(tb) + e.with_traceback(filtered_tb) + if mode == "quiet_remove_frames": + e.add_note("--------------------\n" + _simplified_tb_msg) + else: + if mode == "remove_frames": + msg = format_exception_only(e) + msg = f"{msg}\n\n{_jax_message_append}" + jax_error = UnfilteredStackTrace(msg) + jax_error.with_traceback(_add_call_stack_frames(tb)) + else: + raise ValueError( + f"JAX_TRACEBACK_FILTERING={mode} is not a valid value." + ) + jax_error.__cause__ = e.__cause__ + jax_error.__context__ = e.__context__ + jax_error.__suppress_context__ = e.__suppress_context__ + e.__cause__ = jax_error + e.__context__ = None + raise + finally: + del filtered_tb + del unfiltered + del mode + + return cast(C, reraise_with_filtered_traceback) + + +@dataclasses.dataclass +class TracebackCaches: + traceback_cache: dict[Traceback, Location] + location_cache: dict[tuple[types.CodeType, int], Location] + canonical_name_cache: dict[str, str] + is_user_file_cache: dict[str, bool] + + def __init__(self): + self.traceback_cache = {} + self.location_cache = {} + self.canonical_name_cache = {} + self.is_user_file_cache = {} + + +_traceback_caches = TracebackCaches() +_traceback_in_locations_limit = 100 +_include_full_tracebacks_in_locations = True diff --git a/mlir/test/python/ir/line_info.py b/mlir/test/python/ir/line_info.py new file mode 100644 index 0000000000000..6c0eea255d2df --- /dev/null +++ b/mlir/test/python/ir/line_info.py @@ -0,0 +1,39 @@ +# RUN: %PYTHON %s | FileCheck %s +import gc +import traceback + +from mlir import source_info_util 
from mlir.source_info_util import _traceback_to_location
from mlir import traceback_util
from mlir.ir import Context

# CHECK: hello
print("hello")


# traceback_util.register_exclusion(__file__)


def run(f):
    """Decorator-style test driver.

    Prints a banner with the test's name, calls `f` inside a fresh MLIR
    Context, forces a garbage collection, and returns `f` unchanged.
    """
    print("\nTEST:", f.__name__)
    with Context():
        f()
    gc.collect()
    # assert Context._get_live_count() == 0
    return f


@run
def foo():
    # The nested function is kept (rather than inlined) because its frame is
    # part of the traceback captured by source_info_util.current().
    def bar():
        info = source_info_util.current()
        print(info.name_stack)
        print(info.traceback)

        python_tb = info.traceback.as_python_traceback()
        traceback.print_tb(traceback_util.filter_traceback(python_tb))

        print(_traceback_to_location(info.traceback))

    bar()