@@ -51,60 +51,165 @@ using namespace at;
5151using namespace torch ;
5252using namespace torch ::autograd;
5353
54- std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs (
55- const c10::OperatorHandle& op,
56- const std::vector<c10::IValue>& arguments) {
57- TORCH_CHECK (
58- PyGILState_Check (),
59- " GIL must be held before you call parseIValuesToPyArgsKwargs" );
60- const auto & schema = op.schema ();
61- py::dict kwargs;
54+ namespace {
55+ class OperatorArgsKwargsView {
56+ public:
57+ OperatorArgsKwargsView (
58+ const c10::OperatorHandle& op,
59+ const std::vector<c10::IValue>& arguments);
60+ using args_iterator = const c10::IValue*;
61+
62+ args_iterator args_begin () const {
63+ return arguments_.data ();
64+ }
65+
66+ args_iterator args_end () const {
67+ return arguments_.data () + positional_default_start_;
68+ }
69+
70+ auto num_positional_args () const {
71+ return positional_default_start_;
72+ }
73+
74+ auto kwarg_start_index () const {
75+ return first_non_default_kwarg_;
76+ }
77+
78+ struct kwargs_iterator {
79+ kwargs_iterator () = default ;
80+ kwargs_iterator (const OperatorArgsKwargsView* parent, size_t current)
81+ : parent_(parent), current_(current) {}
82+
83+ kwargs_iterator (const kwargs_iterator&) = default ;
84+ kwargs_iterator& operator =(const kwargs_iterator&) = default ;
85+
86+ kwargs_iterator& operator ++() {
87+ do {
88+ current_++;
89+ } while (current_ < parent_->arguments_ .size () &&
90+ parent_->is_default (current_));
91+ return *this ;
92+ }
93+
94+ kwargs_iterator operator ++(int ) {
95+ auto copy = *this ;
96+ ++(*this );
97+ return copy;
98+ }
99+
100+ const c10::IValue& operator *() const {
101+ return parent_->arguments_ [current_];
102+ }
103+
104+ const c10::IValue* operator ->() const {
105+ return &operator *();
106+ }
107+
108+ int64_t underlying_index () const {
109+ return current_;
110+ }
111+
112+ bool operator ==(const kwargs_iterator& rhs) const {
113+ return parent_ == rhs.parent_ && current_ == rhs.current_ ;
114+ }
115+
116+ bool operator !=(const kwargs_iterator& rhs) {
117+ return !(*this == rhs);
118+ }
119+
120+ private:
121+ const OperatorArgsKwargsView* parent_ = nullptr ;
122+ size_t current_ = 0 ;
123+ };
124+
125+ kwargs_iterator kwargs_begin () const {
126+ return kwargs_iterator (this , first_non_default_kwarg_);
127+ }
128+
129+ kwargs_iterator kwargs_end () const {
130+ return kwargs_iterator (this , arguments_.size ());
131+ }
132+
133+ private:
134+ bool is_default (size_t idx) const {
135+ const auto & arg = op_.schema ().arguments ()[idx];
136+ if (!arg.default_value ().has_value ()) {
137+ return false ;
138+ }
139+ const auto & default_ivalue = *arg.default_value ();
140+ const auto & ivalue = arguments_[idx];
141+ if (default_ivalue != ivalue) {
142+ return false ;
143+ }
144+ return true ;
145+ }
146+
147+ const c10::OperatorHandle& op_;
148+ c10::ArrayRef<c10::IValue> arguments_;
62149 // About all the pointers:
63150 //
64151 // f(int x, int y = 0, *, int z = 0)
65152 // ^- arguments.size()
66153 // ^- kwarg_only_start
67154 // ^- positional_default_start
68155 // ^- 0
156+ int64_t positional_default_start_;
157+ int64_t first_non_default_kwarg_;
158+ };
69159
160+ OperatorArgsKwargsView::OperatorArgsKwargsView (
161+ const c10::OperatorHandle& op,
162+ const std::vector<c10::IValue>& arguments)
163+ : op_(op), arguments_(arguments) {
70164 // Find the split point between kwarg-only and regular. Since most functions
71165 // don't have kwarg-only arguments, it is more efficient to scan from the
72166 // right (but ideally, this would just be precomputed in FunctionSchema
73167 // itself). (NB: minus one in the loop is because we're testing if the
74168 // *next* argument is kwarg-only before we advance the starting index)
75- int64_t kwarg_only_start = static_cast <int64_t >(arguments.size ());
169+ const int64_t signed_arguments_size = static_cast <int64_t >(arguments.size ());
170+ int64_t kwarg_only_start = signed_arguments_size;
76171 for (; kwarg_only_start > 0 ; kwarg_only_start--) {
77- const auto & arg = schema.arguments ()[kwarg_only_start - 1 ];
172+ const auto & arg = op. schema () .arguments ()[kwarg_only_start - 1 ];
78173 if (!arg.kwarg_only ()) {
79174 break ;
80175 }
81176 }
82177
83178 // Find the first positional argument that isn't defaulted
84- auto is_default = [&](size_t idx) -> bool {
85- const auto & arg = schema.arguments ()[idx];
86- if (!arg.default_value ().has_value ()) {
87- return false ;
88- }
89- const auto & default_ivalue = *arg.default_value ();
90- const auto & ivalue = arguments[idx];
91- if (default_ivalue != ivalue) {
92- return false ;
179+ positional_default_start_ = kwarg_only_start;
180+ for (; positional_default_start_ > 0 ; positional_default_start_--) {
181+ if (!is_default (positional_default_start_ - 1 )) {
182+ break ;
93183 }
94- return true ;
95- };
184+ }
96185
97- int64_t positional_default_start = kwarg_only_start;
98- for (; positional_default_start > 0 ; positional_default_start--) {
99- if (!is_default (positional_default_start - 1 )) {
186+ // kwargs_iterator will skip default kwargs when incremented, but we
187+ // need to skip any initial run of default kwargs ourselves.
188+ first_non_default_kwarg_ = kwarg_only_start;
189+ for (; first_non_default_kwarg_ < signed_arguments_size;
190+ ++first_non_default_kwarg_) {
191+ if (!is_default (first_non_default_kwarg_)) {
100192 break ;
101193 }
102194 }
195+ }
196+ } // namespace
103197
104- auto args =
105- py::reinterpret_steal<py::object>(PyTuple_New (positional_default_start));
198+ std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs (
199+ const c10::OperatorHandle& op,
200+ const std::vector<c10::IValue>& arguments) {
201+ TORCH_CHECK (
202+ PyGILState_Check (),
203+ " GIL must be held before you call parseIValuesToPyArgsKwargs" );
204+ const auto & schema = op.schema ();
205+ py::dict kwargs;
106206
107- auto schemaAwareToPyObject = [&](size_t idx) -> py::object {
207+ OperatorArgsKwargsView args_kwargs (op, arguments);
208+ auto args = py::reinterpret_steal<py::object>(
209+ PyTuple_New (args_kwargs.num_positional_args ()));
210+
211+ auto schemaAwareToPyObject =
212+ [&schema](size_t idx, const c10::IValue& argument) -> py::object {
108213 const auto & arg = schema.arguments ()[idx];
109214 auto match = [&](c10::TypeKind kind) {
110215 const auto & t = arg.real_type ();
@@ -116,38 +221,42 @@ std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
116221 }
117222 return false ;
118223 };
119- if (arguments[idx] .isNone ()) {
224+ if (argument .isNone ()) {
120225 return py::none ();
121226 } else if (match (c10::ScalarTypeType::Kind)) {
122- auto * obj =
123- getTHPDtype (static_cast <c10::ScalarType>(arguments[idx].toInt ()));
227+ auto * obj = getTHPDtype (static_cast <c10::ScalarType>(argument.toInt ()));
124228 return py::reinterpret_borrow<py::object>(
125229 reinterpret_cast <PyObject*>(obj));
126230 } else if (match (c10::LayoutType::Kind)) {
127- auto * obj =
128- getTHPLayout (static_cast <c10::Layout>(arguments[idx].toInt ()));
231+ auto * obj = getTHPLayout (static_cast <c10::Layout>(argument.toInt ()));
129232 return py::reinterpret_borrow<py::object>(
130233 reinterpret_cast <PyObject*>(obj));
131234 } else if (match (c10::MemoryFormatType::Kind)) {
132- return py::cast (static_cast <c10::MemoryFormat>(arguments[idx] .toInt ()));
235+ return py::cast (static_cast <c10::MemoryFormat>(argument .toInt ()));
133236 } else {
134- return torch::jit::toPyObject (arguments[idx] );
237+ return torch::jit::toPyObject (argument );
135238 }
136239 };
137240
138241 // Populate positional arguments
139- for (const auto idx : c10::irange (positional_default_start)) {
242+ size_t idx = 0 ;
243+ for (auto argument_it = args_kwargs.args_begin ();
244+ argument_it != args_kwargs.args_end ();
245+ ++argument_it) {
140246 PyTuple_SET_ITEM (
141- args.ptr (), idx, schemaAwareToPyObject (idx).release ().ptr ());
247+ args.ptr (),
248+ idx,
249+ schemaAwareToPyObject (idx, *argument_it).release ().ptr ());
250+ idx++;
142251 }
143252
144253 // Populate keyword arguments
145- for (const auto idx : c10::irange (kwarg_only_start, arguments. size ())) {
146- // But don't populate default keyword arguments
147- if ( is_default (idx))
148- continue ;
149- const auto & arg = schema. arguments ()[idx];
150- kwargs[ py::cast (arg. name ())] = schemaAwareToPyObject (idx );
254+ for (auto argument_it = args_kwargs. kwargs_begin ();
255+ argument_it != args_kwargs. kwargs_end ();
256+ ++argument_it) {
257+ const auto & arg = schema. arguments ()[argument_it. underlying_index ()] ;
258+ kwargs[ py::cast ( arg. name ())] =
259+ schemaAwareToPyObject (argument_it. underlying_index (), *argument_it );
151260 }
152261 return std::make_pair (std::move (args), std::move (kwargs));
153262}
0 commit comments