3838#include < torch/csrc/autograd/variable.h>
3939#include < torch/extension.h>
4040#include < torch/torch.h>
41+ #include < vector>
4142
4243// Pybind requires to have a central include in order for type casters to work.
4344// Opaque bindings add a type caster, so they have the same requirement.
@@ -48,7 +49,6 @@ NB_MAKE_OPAQUE(tensorrt_llm::batch_manager::ReqIdsSet)
4849NB_MAKE_OPAQUE(std::vector<tensorrt_llm::batch_manager::SlotDecoderBuffers>)
4950NB_MAKE_OPAQUE(std::vector<tensorrt_llm::runtime::decoder_batch::Request>)
5051NB_MAKE_OPAQUE(std::vector<tensorrt_llm::runtime::SamplingConfig>)
51- NB_MAKE_OPAQUE(std::vector<std::vector<tensorrt_llm::runtime::SizeType32>>)
5252
5353namespace nb = nanobind;
5454
@@ -128,70 +128,6 @@ struct type_caster<tensorrt_llm::common::OptionalRef<T>>
128128 }
129129};
130130
131- template <typename T>
132- struct PathCaster
133- {
134-
135- private:
136- static PyObject* unicode_from_fs_native (std::string const & w)
137- {
138- return PyUnicode_DecodeFSDefaultAndSize (w.c_str (), ssize_t (w.size ()));
139- }
140-
141- static PyObject* unicode_from_fs_native (std::wstring const & w)
142- {
143- return PyUnicode_FromWideChar (w.c_str (), ssize_t (w.size ()));
144- }
145-
146- public:
147- static handle from_cpp (T const & path, rv_policy, cleanup_list* cleanup)
148- {
149- if (auto py_str = unicode_from_fs_native (path.native ()))
150- {
151- return module_::import_ (" pathlib" ).attr (" Path" )(steal<object>(py_str), cleanup).release ();
152- }
153- return nullptr ;
154- }
155-
156- bool from_python (handle src, uint8_t flags, cleanup_list* cleanup)
157- {
158- PyObject* native = nullptr ;
159- if constexpr (std::is_same_v<typename T::value_type, char >)
160- {
161- if (PyUnicode_FSConverter (src.ptr (), &native) != 0 )
162- {
163- if (auto * c_str = PyBytes_AsString (native))
164- {
165- // AsString returns a pointer to the internal buffer, which
166- // must not be free'd.
167- value = c_str;
168- }
169- }
170- }
171- else if constexpr (std::is_same_v<typename T::value_type, wchar_t >)
172- {
173- if (PyUnicode_FSDecoder (src.ptr (), &native) != 0 )
174- {
175- if (auto * c_str = PyUnicode_AsWideCharString (native, nullptr ))
176- {
177- // AsWideCharString returns a new string that must be free'd.
178- value = c_str; // Copies the string.
179- PyMem_Free (c_str);
180- }
181- }
182- }
183- Py_XDECREF (native);
184- if (PyErr_Occurred ())
185- {
186- PyErr_Clear ();
187- return false ;
188- }
189- return true ;
190- }
191-
192- NB_TYPE_CASTER (T, const_name(" os.PathLike" ));
193- };
194-
195131template <>
196132class type_caster <tensorrt_llm::executor::StreamPtr>
197133{
@@ -311,34 +247,45 @@ struct type_caster<at::Tensor>
311247
312248 bool from_python (nb::handle src, uint8_t , cleanup_list*) noexcept
313249 {
314- nb::object capsule = nb::getattr (src, " __dlpack__" )();
315- DLManagedTensor* dl_managed = static_cast <DLManagedTensor*>(PyCapsule_GetPointer (capsule.ptr (), " dltensor" ));
316- PyCapsule_SetDestructor (capsule.ptr (), nullptr );
317- value = at::fromDLPack (dl_managed).alias ();
318- return true ;
250+ PyObject* obj = src.ptr ();
251+ if (THPVariable_Check (obj))
252+ {
253+ value = THPVariable_Unpack (obj);
254+ return true ;
255+ }
256+ return false ;
319257 }
320258
321- static handle from_cpp (at::Tensor tensor , rv_policy, cleanup_list*) noexcept
259+ static handle from_cpp (at::Tensor src , rv_policy, cleanup_list*) noexcept
322260 {
323- DLManagedTensor* dl_managed = at::toDLPack (tensor);
324- if (!dl_managed)
325- return nullptr ;
326-
327- nanobind::object capsule = nb::steal (PyCapsule_New (dl_managed, " dltensor" ,
328- [](PyObject* obj)
329- {
330- DLManagedTensor* dl = static_cast <DLManagedTensor*>(PyCapsule_GetPointer (obj, " dltensor" ));
331- dl->deleter (dl);
332- }));
333- if (!capsule.is_valid ())
261+ return THPVariable_Wrap (src);
262+ }
263+ };
264+
265+ template <typename T>
266+ struct type_caster <std::vector<std::reference_wrapper<T const >>>
267+ {
268+ using VectorType = std::vector<std::reference_wrapper<T const >>;
269+
270+ NB_TYPE_CASTER (VectorType, const_name(" List[" ) + make_caster<T>::Name + const_name(" ]" ));
271+
272+ bool from_python (handle src, uint8_t flags, cleanup_list* cleanup) noexcept
273+ {
274+ // Not needed for our use case since we only convert C++ to Python
275+ return false ;
276+ }
277+
278+ static handle from_cpp (VectorType const & src, rv_policy policy, cleanup_list* cleanup) noexcept
279+ {
280+
281+ std::vector<T> result;
282+ result.reserve (src.size ());
283+ for (auto const & ref : src)
334284 {
335- dl_managed->deleter (dl_managed);
336- return nullptr ;
285+ result.push_back (ref.get ());
337286 }
338- nanobind::module_ torch = nanobind::module_::import_ (" torch" );
339- nanobind::object result = torch.attr (" from_dlpack" )(capsule);
340- capsule.release ();
341- return result.release ();
287+
288+ return make_caster<std::vector<T>>::from_cpp (result, policy, cleanup);
342289 }
343290};
344291} // namespace detail
0 commit comments