forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path opaque_obj.h
More file actions
79 lines (69 loc) · 2.81 KB
/
opaque_obj.h
File metadata and controls
79 lines (69 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#pragma once
#include <string>
#include <utility>
#include <c10/macros/Macros.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/custom_class.h>
namespace torch::jit {

/// TorchScript custom class that carries an arbitrary Python object (the
/// "payload") through TorchScript without TorchScript inspecting it.
///
/// The payload is a pybind11 handle. Copying, replacing, or destroying it
/// changes a Python reference count, which must happen with the GIL held.
/// The accessors below do NOT acquire the GIL themselves; their callers are
/// responsible for holding it.
/// NOTE(review): the `__eq__` registration below guards against
/// `payload_.ptr() == nullptr`, so a null handle is treated as a possible
/// state of this object.
struct OpaqueObject : public CustomClassHolder {
  /// Stores `payload` (moved in; no additional reference is taken here).
  OpaqueObject(py::object payload) : payload_(std::move(payload)) {}

  /// Replaces the stored payload, dropping the reference to the previous one.
  /// Caller must hold the GIL.
  void setPayload(py::object payload) {
    payload_ = std::move(payload);
  }

  /// Returns a new reference to the stored payload (copying a py::object
  /// increments its refcount). Caller must hold the GIL.
  py::object getPayload() const {
    return payload_;
  }

  py::object payload_;
};

// Registers OpaqueObject with the TorchScript custom-class registry under
// namespace "aten", class name "OpaqueObject".
//
// `inline` (C++17) gives this header-defined variable one definition for the
// whole program. The registration therefore runs once even if the header is
// included from several translation units. A namespace-scope `static` would
// give every including TU its own copy, registering the same class repeatedly.
inline auto register_opaque_obj_class =
    torch::class_<OpaqueObject>("aten", "OpaqueObject")
        .def(
            "__eq__",
            // Python equality (`==`) of the two payloads. A null payload on
            // either side compares unequal; a Python exception raised by the
            // comparison is propagated as py::error_already_set.
            [](const c10::intrusive_ptr<OpaqueObject>& self,
               const c10::intrusive_ptr<OpaqueObject>& other) {
              // Acquire the GIL BEFORE copying the payloads. Each copy
              // increments a Python refcount and is released when this lambda
              // returns. Locals are destroyed in reverse declaration order, so
              // declaring the guard first keeps the GIL held until after those
              // decrements as well.
              py::gil_scoped_acquire gil;
              py::object self_payload = self->getPayload();
              py::object other_payload = other->getPayload();
              if (!self_payload.ptr() || !other_payload.ptr()) {
                return false;
              }
              const int res = PyObject_RichCompareBool(
                  self_payload.ptr(), other_payload.ptr(), Py_EQ);
              if (res == -1) {
                // The comparison raised; surface the pending Python error.
                throw py::error_already_set();
              }
              return res > 0;
            })
        .def_pickle(
            // __getstate__: custom-class state cannot be returned as a
            // py::object (or py::bytes) because of CustomClassHolder's
            // signature limitations. Instead the payload is pickled and then
            // base64-encoded so it fits in an ASCII std::string.
            [](const c10::intrusive_ptr<OpaqueObject>& self) {
              // Everything below uses the Python C API. The guard is declared
              // first so it outlives (is released after) the temporaries.
              py::gil_scoped_acquire gil;
              py::module_ pickle = py::module_::import("pickle");
              py::module_ base64 = py::module_::import("base64");
              py::bytes pickled_payload =
                  pickle.attr("dumps")(self->getPayload());
              py::bytes encoded_payload =
                  base64.attr("b64encode")(pickled_payload);
              return std::string(encoded_payload);
            },
            // __setstate__: inverse of __getstate__ (base64-decode, then
            // unpickle).
            // SECURITY NOTE(review): pickle.loads can execute arbitrary code
            // embedded in its input. `state` presumably comes from a
            // serialized artifact being loaded; loading an untrusted artifact
            // that contains an OpaqueObject could run attacker-controlled code.
            [](const std::string& state) {
              py::gil_scoped_acquire gil;
              py::module_ pickle = py::module_::import("pickle");
              py::module_ base64 = py::module_::import("base64");
              py::bytes state_bytes(state);
              py::bytes decoded_payload =
                  base64.attr("b64decode")(state_bytes);
              py::object restored_payload =
                  pickle.attr("loads")(decoded_payload);
              // Move the handle in to avoid an extra incref/decref pair.
              return c10::make_intrusive<OpaqueObject>(
                  std::move(restored_payload));
            })
        .def(
            "__obj_flatten__",
            // The payload is opaque, so there is no structured state to
            // flatten; this always throws.
            [](const c10::intrusive_ptr<OpaqueObject>& /*self*/) {
              throw std::runtime_error(
                  "Unable to implement __obj_flatten__ for opaque objects.");
            });

} // namespace torch::jit