Skip to content

Commit 981dd71

Browse files
swolchokpytorchmergebot
authored andcommitted
Refactor: extract OperatorArgsKwargsView from parseIValuesToPyArgsKwargs (pytorch#166368)
Intended to make it easier to reuse this logic for processing operator arguments as IValues in following PR(s). Testing: python test/test_python_dispatch.py (broke during development, seems to work now) Pull Request resolved: pytorch#166368 Approved by: https://github.com/albanD
1 parent d31599f commit 981dd71

File tree

1 file changed

+151
-42
lines changed

1 file changed

+151
-42
lines changed

torch/csrc/autograd/python_variable.cpp

Lines changed: 151 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -51,60 +51,165 @@ using namespace at;
5151
using namespace torch;
5252
using namespace torch::autograd;
5353

54-
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
55-
const c10::OperatorHandle& op,
56-
const std::vector<c10::IValue>& arguments) {
57-
TORCH_CHECK(
58-
PyGILState_Check(),
59-
"GIL must be held before you call parseIValuesToPyArgsKwargs");
60-
const auto& schema = op.schema();
61-
py::dict kwargs;
54+
namespace {
55+
class OperatorArgsKwargsView {
56+
public:
57+
OperatorArgsKwargsView(
58+
const c10::OperatorHandle& op,
59+
const std::vector<c10::IValue>& arguments);
60+
using args_iterator = const c10::IValue*;
61+
62+
args_iterator args_begin() const {
63+
return arguments_.data();
64+
}
65+
66+
args_iterator args_end() const {
67+
return arguments_.data() + positional_default_start_;
68+
}
69+
70+
auto num_positional_args() const {
71+
return positional_default_start_;
72+
}
73+
74+
auto kwarg_start_index() const {
75+
return first_non_default_kwarg_;
76+
}
77+
78+
struct kwargs_iterator {
79+
kwargs_iterator() = default;
80+
kwargs_iterator(const OperatorArgsKwargsView* parent, size_t current)
81+
: parent_(parent), current_(current) {}
82+
83+
kwargs_iterator(const kwargs_iterator&) = default;
84+
kwargs_iterator& operator=(const kwargs_iterator&) = default;
85+
86+
kwargs_iterator& operator++() {
87+
do {
88+
current_++;
89+
} while (current_ < parent_->arguments_.size() &&
90+
parent_->is_default(current_));
91+
return *this;
92+
}
93+
94+
kwargs_iterator operator++(int) {
95+
auto copy = *this;
96+
++(*this);
97+
return copy;
98+
}
99+
100+
const c10::IValue& operator*() const {
101+
return parent_->arguments_[current_];
102+
}
103+
104+
const c10::IValue* operator->() const {
105+
return &operator*();
106+
}
107+
108+
int64_t underlying_index() const {
109+
return current_;
110+
}
111+
112+
bool operator==(const kwargs_iterator& rhs) const {
113+
return parent_ == rhs.parent_ && current_ == rhs.current_;
114+
}
115+
116+
bool operator!=(const kwargs_iterator& rhs) {
117+
return !(*this == rhs);
118+
}
119+
120+
private:
121+
const OperatorArgsKwargsView* parent_ = nullptr;
122+
size_t current_ = 0;
123+
};
124+
125+
kwargs_iterator kwargs_begin() const {
126+
return kwargs_iterator(this, first_non_default_kwarg_);
127+
}
128+
129+
kwargs_iterator kwargs_end() const {
130+
return kwargs_iterator(this, arguments_.size());
131+
}
132+
133+
private:
134+
bool is_default(size_t idx) const {
135+
const auto& arg = op_.schema().arguments()[idx];
136+
if (!arg.default_value().has_value()) {
137+
return false;
138+
}
139+
const auto& default_ivalue = *arg.default_value();
140+
const auto& ivalue = arguments_[idx];
141+
if (default_ivalue != ivalue) {
142+
return false;
143+
}
144+
return true;
145+
}
146+
147+
const c10::OperatorHandle& op_;
148+
c10::ArrayRef<c10::IValue> arguments_;
62149
// About all the pointers:
63150
//
64151
// f(int x, int y = 0, *, int z = 0)
65152
// ^- arguments.size()
66153
// ^- kwarg_only_start
67154
// ^- positional_default_start
68155
// ^- 0
156+
int64_t positional_default_start_;
157+
int64_t first_non_default_kwarg_;
158+
};
69159

160+
OperatorArgsKwargsView::OperatorArgsKwargsView(
161+
const c10::OperatorHandle& op,
162+
const std::vector<c10::IValue>& arguments)
163+
: op_(op), arguments_(arguments) {
70164
// Find the split point between kwarg-only and regular. Since most functions
71165
// don't have kwarg-only arguments, it is more efficient to scan from the
72166
// right (but ideally, this would just be precomputed in FunctionSchema
73167
// itself). (NB: minus one in the loop is because we're testing if the
74168
// *next* argument is kwarg-only before we advance the starting index)
75-
int64_t kwarg_only_start = static_cast<int64_t>(arguments.size());
169+
const int64_t signed_arguments_size = static_cast<int64_t>(arguments.size());
170+
int64_t kwarg_only_start = signed_arguments_size;
76171
for (; kwarg_only_start > 0; kwarg_only_start--) {
77-
const auto& arg = schema.arguments()[kwarg_only_start - 1];
172+
const auto& arg = op.schema().arguments()[kwarg_only_start - 1];
78173
if (!arg.kwarg_only()) {
79174
break;
80175
}
81176
}
82177

83178
// Find the first positional argument that isn't defaulted
84-
auto is_default = [&](size_t idx) -> bool {
85-
const auto& arg = schema.arguments()[idx];
86-
if (!arg.default_value().has_value()) {
87-
return false;
88-
}
89-
const auto& default_ivalue = *arg.default_value();
90-
const auto& ivalue = arguments[idx];
91-
if (default_ivalue != ivalue) {
92-
return false;
179+
positional_default_start_ = kwarg_only_start;
180+
for (; positional_default_start_ > 0; positional_default_start_--) {
181+
if (!is_default(positional_default_start_ - 1)) {
182+
break;
93183
}
94-
return true;
95-
};
184+
}
96185

97-
int64_t positional_default_start = kwarg_only_start;
98-
for (; positional_default_start > 0; positional_default_start--) {
99-
if (!is_default(positional_default_start - 1)) {
186+
// kwargs_iterator will skip default kwargs when incremented, but we
187+
// need to skip any initial run of default kwargs ourselves.
188+
first_non_default_kwarg_ = kwarg_only_start;
189+
for (; first_non_default_kwarg_ < signed_arguments_size;
190+
++first_non_default_kwarg_) {
191+
if (!is_default(first_non_default_kwarg_)) {
100192
break;
101193
}
102194
}
195+
}
196+
} // namespace
103197

104-
auto args =
105-
py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));
198+
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
199+
const c10::OperatorHandle& op,
200+
const std::vector<c10::IValue>& arguments) {
201+
TORCH_CHECK(
202+
PyGILState_Check(),
203+
"GIL must be held before you call parseIValuesToPyArgsKwargs");
204+
const auto& schema = op.schema();
205+
py::dict kwargs;
106206

107-
auto schemaAwareToPyObject = [&](size_t idx) -> py::object {
207+
OperatorArgsKwargsView args_kwargs(op, arguments);
208+
auto args = py::reinterpret_steal<py::object>(
209+
PyTuple_New(args_kwargs.num_positional_args()));
210+
211+
auto schemaAwareToPyObject =
212+
[&schema](size_t idx, const c10::IValue& argument) -> py::object {
108213
const auto& arg = schema.arguments()[idx];
109214
auto match = [&](c10::TypeKind kind) {
110215
const auto& t = arg.real_type();
@@ -116,38 +221,42 @@ std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
116221
}
117222
return false;
118223
};
119-
if (arguments[idx].isNone()) {
224+
if (argument.isNone()) {
120225
return py::none();
121226
} else if (match(c10::ScalarTypeType::Kind)) {
122-
auto* obj =
123-
getTHPDtype(static_cast<c10::ScalarType>(arguments[idx].toInt()));
227+
auto* obj = getTHPDtype(static_cast<c10::ScalarType>(argument.toInt()));
124228
return py::reinterpret_borrow<py::object>(
125229
reinterpret_cast<PyObject*>(obj));
126230
} else if (match(c10::LayoutType::Kind)) {
127-
auto* obj =
128-
getTHPLayout(static_cast<c10::Layout>(arguments[idx].toInt()));
231+
auto* obj = getTHPLayout(static_cast<c10::Layout>(argument.toInt()));
129232
return py::reinterpret_borrow<py::object>(
130233
reinterpret_cast<PyObject*>(obj));
131234
} else if (match(c10::MemoryFormatType::Kind)) {
132-
return py::cast(static_cast<c10::MemoryFormat>(arguments[idx].toInt()));
235+
return py::cast(static_cast<c10::MemoryFormat>(argument.toInt()));
133236
} else {
134-
return torch::jit::toPyObject(arguments[idx]);
237+
return torch::jit::toPyObject(argument);
135238
}
136239
};
137240

138241
// Populate positional arguments
139-
for (const auto idx : c10::irange(positional_default_start)) {
242+
size_t idx = 0;
243+
for (auto argument_it = args_kwargs.args_begin();
244+
argument_it != args_kwargs.args_end();
245+
++argument_it) {
140246
PyTuple_SET_ITEM(
141-
args.ptr(), idx, schemaAwareToPyObject(idx).release().ptr());
247+
args.ptr(),
248+
idx,
249+
schemaAwareToPyObject(idx, *argument_it).release().ptr());
250+
idx++;
142251
}
143252

144253
// Populate keyword arguments
145-
for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) {
146-
// But don't populate default keyword arguments
147-
if (is_default(idx))
148-
continue;
149-
const auto& arg = schema.arguments()[idx];
150-
kwargs[py::cast(arg.name())] = schemaAwareToPyObject(idx);
254+
for (auto argument_it = args_kwargs.kwargs_begin();
255+
argument_it != args_kwargs.kwargs_end();
256+
++argument_it) {
257+
const auto& arg = schema.arguments()[argument_it.underlying_index()];
258+
kwargs[py::cast(arg.name())] =
259+
schemaAwareToPyObject(argument_it.underlying_index(), *argument_it);
151260
}
152261
return std::make_pair(std::move(args), std::move(kwargs));
153262
}

0 commit comments

Comments
 (0)