
Commit 6007373

fix: bindings unit tests for nanobind (NVIDIA#6221)
Signed-off-by: Linda-Stadter <[email protected]>
1 parent 04f2d4b commit 6007373

10 files changed: +157 −214 lines

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ void initBindings(nb::module_& m)
         }
     });

-    PybindUtils::bindSet<tb::ReqIdsSet>(m, "ReqIdsSet");
+    NanobindUtils::bindSet<tb::ReqIdsSet>(m, "ReqIdsSet");

     nb::enum_<tb::LlmRequestType>(m, "LlmRequestType")
         .value("LLMREQUEST_TYPE_CONTEXT_AND_GENERATION", tb::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION)

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Lines changed: 12 additions & 1 deletion
@@ -48,6 +48,9 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32;
 using TokenIdType = tensorrt_llm::runtime::TokenIdType;
 using VecTokens = std::vector<TokenIdType>;
 using CudaStreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
+using CacheBlockIds = std::vector<std::vector<SizeType32>>;
+
+NB_MAKE_OPAQUE(CacheBlockIds);

 namespace
 {

@@ -424,7 +427,15 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
         .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds)
         .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents);

-    nb::bind_vector<std::vector<std::vector<SizeType32>>>(m, "CacheBlockIds");
+    nb::bind_vector<CacheBlockIds>(m, "CacheBlockIds")
+        .def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); })
+        .def("__setstate__",
+            [](CacheBlockIds& self, nb::tuple const& t)
+            {
+                if (t.size() != 1)
+                    throw std::runtime_error("Invalid state!");
+                new (&self) CacheBlockIds(nb::cast<std::vector<std::vector<SizeType32>>>(t[0]));
+            });

     nb::enum_<tbk::CacheType>(m, "CacheType")
        .value("SELF", tbk::CacheType::kSELF)
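Note: nanobind's pickle protocol differs from pybind11's — `__setstate__` receives a reference to uninitialized storage and must construct the object with placement new rather than return a new instance. Below is a minimal self-contained sketch of the same opaque-vector pickling pattern; the module name, element type, and state layout are illustrative and not part of this commit.

#include <stdexcept>
#include <vector>
#include <nanobind/nanobind.h>
#include <nanobind/stl/bind_vector.h>

namespace nb = nanobind;

using IntVector = std::vector<int>;
NB_MAKE_OPAQUE(IntVector); // expose the bound class instead of converting to a Python list

NB_MODULE(pickle_example, m)
{
    nb::bind_vector<IntVector>(m, "IntVector")
        // __getstate__ flattens the contents into plain Python objects so the
        // pickled state does not depend on the opaque type itself.
        .def("__getstate__",
            [](IntVector const& v)
            {
                nb::list items;
                for (int x : v)
                    items.append(x); // each element converts to a Python int
                return nb::make_tuple(items);
            })
        // __setstate__ runs on uninitialized storage: construct in place.
        .def("__setstate__",
            [](IntVector& self, nb::tuple const& t)
            {
                if (t.size() != 1)
                    throw std::runtime_error("Invalid state!");
                new (&self) IntVector();
                for (nb::handle h : nb::cast<nb::list>(t[0]))
                    self.push_back(nb::cast<int>(h));
            });
}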

cpp/tensorrt_llm/nanobind/bindings.cpp

Lines changed: 6 additions & 3 deletions
@@ -359,9 +359,12 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
             config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP,
             config.beamWidthArray);
     };
-    auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t) -> tr::SamplingConfig
+    auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t)
     {
-        assert(t.size() == 19);
+        if (t.size() != 19)
+        {
+            throw std::runtime_error("Invalid SamplingConfig state!");
+        }

         tr::SamplingConfig config;
         config.beamWidth = nb::cast<SizeType32>(t[0]);

@@ -384,7 +387,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
         config.minP = nb::cast<OptVec<float>>(t[17]);
         config.beamWidthArray = nb::cast<OptVec<std::vector<SizeType32>>>(t[18]);

-        return config;
+        new (&self) tr::SamplingConfig(config);
     };

     nb::class_<tr::SamplingConfig>(m, "SamplingConfig")
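The switch from `assert` to a thrown exception matters because `assert` compiles to a no-op when NDEBUG is defined (typical release builds), so a malformed pickle tuple would previously pass silently; an exception surfaces in Python as a RuntimeError in every build type. A standalone sketch of the same get/set-state wiring on a toy config — the `MiniConfig` type and module name are illustrative, not from this commit:

#include <stdexcept>
#include <nanobind/nanobind.h>

namespace nb = nanobind;

struct MiniConfig
{
    int beamWidth = 1;
};

NB_MODULE(config_example, m)
{
    auto getState = [](MiniConfig const& c) { return nb::make_tuple(c.beamWidth); };
    auto setState = [](MiniConfig& self, nb::tuple t)
    {
        // Runtime check: unlike assert(), this also fires in NDEBUG builds.
        if (t.size() != 1)
        {
            throw std::runtime_error("Invalid MiniConfig state!");
        }
        MiniConfig config;
        config.beamWidth = nb::cast<int>(t[0]);
        // Construct in place; nanobind passes uninitialized storage to __setstate__.
        new (&self) MiniConfig(config);
    };

    nb::class_<MiniConfig>(m, "MiniConfig")
        .def(nb::init<>())
        .def_rw("beam_width", &MiniConfig::beamWidth)
        .def("__getstate__", getState)
        .def("__setstate__", setState);
}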

cpp/tensorrt_llm/nanobind/common/bindTypes.h

Lines changed: 3 additions & 36 deletions
@@ -21,44 +21,11 @@
 #include <nanobind/nanobind.h>
 #include <nanobind/stl/string.h>

-namespace PybindUtils
+namespace NanobindUtils
 {

 namespace nb = nanobind;

-template <typename T>
-void bindList(nb::module_& m, std::string const& name)
-{
-    nb::class_<T>(m, name.c_str())
-        .def(nb::init<>())
-        .def("push_back", [](T& lst, const typename T::value_type& value) { lst.push_back(value); })
-        .def("pop_back", [](T& lst) { lst.pop_back(); })
-        .def("push_front", [](T& lst, const typename T::value_type& value) { lst.push_front(value); })
-        .def("pop_front", [](T& lst) { lst.pop_front(); })
-        .def("__len__", [](T const& lst) { return lst.size(); })
-        .def(
-            "__iter__", [](T& lst) { return nb::make_iterator(nb::type<T>(), "iterator", lst.begin(), lst.end()); },
-            nb::keep_alive<0, 1>())
-        .def("__getitem__",
-            [](T const& lst, size_t index)
-            {
-                if (index >= lst.size())
-                    throw nb::index_error();
-                auto it = lst.begin();
-                std::advance(it, index);
-                return *it;
-            })
-        .def("__setitem__",
-            [](T& lst, size_t index, const typename T::value_type& value)
-            {
-                if (index >= lst.size())
-                    throw nb::index_error();
-                auto it = lst.begin();
-                std::advance(it, index);
-                *it = value;
-            });
-}
-
 template <typename T>
 void bindSet(nb::module_& m, std::string const& name)
 {

@@ -93,8 +60,8 @@ void bindSet(nb::module_& m, std::string const& name)
         {
             s.insert(item);
         }
-        return s;
+        new (&v) T(s);
     });
 }

-} // namespace PybindUtils
+} // namespace NanobindUtils
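The unused `bindList` helper is dropped along with the rename from `PybindUtils` to `NanobindUtils`, and `bindSet`'s `__setstate__` tail adopts the same placement-new convention shown above. For context, a minimal sketch of how `bindSet` might be instantiated in a standalone module; the module name, set type, and include path are illustrative:

#include <set>
#include <nanobind/nanobind.h>
#include "tensorrt_llm/nanobind/common/bindTypes.h" // assumed include path for NanobindUtils

namespace nb = nanobind;

using IdSet = std::set<unsigned long>;
NB_MAKE_OPAQUE(IdSet); // expose the bound class rather than a converted Python set

NB_MODULE(set_example, m)
{
    // Binds std::set<unsigned long> as a Python class named "IdSet",
    // with the pickle support defined in bindSet.
    NanobindUtils::bindSet<IdSet>(m, "IdSet");
}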

cpp/tensorrt_llm/nanobind/common/customCasters.h

Lines changed: 35 additions & 88 deletions
@@ -38,6 +38,7 @@
 #include <torch/csrc/autograd/variable.h>
 #include <torch/extension.h>
 #include <torch/torch.h>
+#include <vector>

 // Pybind requires to have a central include in order for type casters to work.
 // Opaque bindings add a type caster, so they have the same requirement.

@@ -48,7 +49,6 @@ NB_MAKE_OPAQUE(tensorrt_llm::batch_manager::ReqIdsSet)
 NB_MAKE_OPAQUE(std::vector<tensorrt_llm::batch_manager::SlotDecoderBuffers>)
 NB_MAKE_OPAQUE(std::vector<tensorrt_llm::runtime::decoder_batch::Request>)
 NB_MAKE_OPAQUE(std::vector<tensorrt_llm::runtime::SamplingConfig>)
-NB_MAKE_OPAQUE(std::vector<std::vector<tensorrt_llm::runtime::SizeType32>>)

 namespace nb = nanobind;

@@ -128,70 +128,6 @@ struct type_caster<tensorrt_llm::common::OptionalRef<T>>
     }
 };

-template <typename T>
-struct PathCaster
-{
-
-private:
-    static PyObject* unicode_from_fs_native(std::string const& w)
-    {
-        return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), ssize_t(w.size()));
-    }
-
-    static PyObject* unicode_from_fs_native(std::wstring const& w)
-    {
-        return PyUnicode_FromWideChar(w.c_str(), ssize_t(w.size()));
-    }
-
-public:
-    static handle from_cpp(T const& path, rv_policy, cleanup_list* cleanup)
-    {
-        if (auto py_str = unicode_from_fs_native(path.native()))
-        {
-            return module_::import_("pathlib").attr("Path")(steal<object>(py_str), cleanup).release();
-        }
-        return nullptr;
-    }
-
-    bool from_python(handle src, uint8_t flags, cleanup_list* cleanup)
-    {
-        PyObject* native = nullptr;
-        if constexpr (std::is_same_v<typename T::value_type, char>)
-        {
-            if (PyUnicode_FSConverter(src.ptr(), &native) != 0)
-            {
-                if (auto* c_str = PyBytes_AsString(native))
-                {
-                    // AsString returns a pointer to the internal buffer, which
-                    // must not be free'd.
-                    value = c_str;
-                }
-            }
-        }
-        else if constexpr (std::is_same_v<typename T::value_type, wchar_t>)
-        {
-            if (PyUnicode_FSDecoder(src.ptr(), &native) != 0)
-            {
-                if (auto* c_str = PyUnicode_AsWideCharString(native, nullptr))
-                {
-                    // AsWideCharString returns a new string that must be free'd.
-                    value = c_str; // Copies the string.
-                    PyMem_Free(c_str);
-                }
-            }
-        }
-        Py_XDECREF(native);
-        if (PyErr_Occurred())
-        {
-            PyErr_Clear();
-            return false;
-        }
-        return true;
-    }
-
-    NB_TYPE_CASTER(T, const_name("os.PathLike"));
-};
-
 template <>
 class type_caster<tensorrt_llm::executor::StreamPtr>
 {

@@ -311,34 +247,45 @@ struct type_caster<at::Tensor>

     bool from_python(nb::handle src, uint8_t, cleanup_list*) noexcept
     {
-        nb::object capsule = nb::getattr(src, "__dlpack__")();
-        DLManagedTensor* dl_managed = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(capsule.ptr(), "dltensor"));
-        PyCapsule_SetDestructor(capsule.ptr(), nullptr);
-        value = at::fromDLPack(dl_managed).alias();
-        return true;
+        PyObject* obj = src.ptr();
+        if (THPVariable_Check(obj))
+        {
+            value = THPVariable_Unpack(obj);
+            return true;
+        }
+        return false;
     }

-    static handle from_cpp(at::Tensor tensor, rv_policy, cleanup_list*) noexcept
+    static handle from_cpp(at::Tensor src, rv_policy, cleanup_list*) noexcept
     {
-        DLManagedTensor* dl_managed = at::toDLPack(tensor);
-        if (!dl_managed)
-            return nullptr;
-
-        nanobind::object capsule = nb::steal(PyCapsule_New(dl_managed, "dltensor",
-            [](PyObject* obj)
-            {
-                DLManagedTensor* dl = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(obj, "dltensor"));
-                dl->deleter(dl);
-            }));
-        if (!capsule.is_valid())
-        {
-            dl_managed->deleter(dl_managed);
-            return nullptr;
-        }
-        nanobind::module_ torch = nanobind::module_::import_("torch");
-        nanobind::object result = torch.attr("from_dlpack")(capsule);
-        capsule.release();
-        return result.release();
+        return THPVariable_Wrap(src);
+    }
+};
+
+template <typename T>
+struct type_caster<std::vector<std::reference_wrapper<T const>>>
+{
+    using VectorType = std::vector<std::reference_wrapper<T const>>;
+
+    NB_TYPE_CASTER(VectorType, const_name("List[") + make_caster<T>::Name + const_name("]"));
+
+    bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept
+    {
+        // Not needed for our use case since we only convert C++ to Python
+        return false;
+    }
+
+    static handle from_cpp(VectorType const& src, rv_policy policy, cleanup_list* cleanup) noexcept
+    {
+        std::vector<T> result;
+        result.reserve(src.size());
+        for (auto const& ref : src)
+        {
+            result.push_back(ref.get());
+        }
+
+        return make_caster<std::vector<T>>::from_cpp(result, policy, cleanup);
     }
 };
 } // namespace detail
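Two caster changes here: `at::Tensor` now crosses the language boundary via `THPVariable_Check`/`THPVariable_Unpack`/`THPVariable_Wrap`, which hands over the same tensor object directly instead of round-tripping through DLPack capsules, and a new one-way caster converts `std::vector<std::reference_wrapper<T const>>` into a Python list by copying the referenced values. A minimal sketch of a binding that relies on the tensor caster; the function and module names are illustrative, and the include path is an assumption:

#include <torch/torch.h>
#include <nanobind/nanobind.h>
#include "tensorrt_llm/nanobind/common/customCasters.h" // assumed path; provides the at::Tensor caster

namespace nb = nanobind;

// Doubles a tensor in place and returns it. With the THPVariable-based caster,
// the argument is unwrapped without a DLPack round-trip, so `t` aliases the
// caller's tensor and the returned object shares the same storage.
at::Tensor double_in_place(at::Tensor t)
{
    t.mul_(2);
    return t;
}

NB_MODULE(tensor_example, m)
{
    m.def("double_in_place", &double_in_place);
}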

cpp/tensorrt_llm/nanobind/executor/executor.cpp

Lines changed: 24 additions & 40 deletions
@@ -52,58 +52,37 @@ struct dtype_traits<half>

 namespace
 {
-// todo: Properly support FP8 and BF16 and verify functionality
-tle::Tensor numpyToTensor(nb::ndarray<nb::numpy> const& array)
+tle::Tensor numpyToTensor(nb::object const& object)
 {
-    auto npDtype = array.dtype();
-    char kind = '\0';
-    switch (npDtype.code)
-    {
-    case static_cast<uint8_t>(nb::dlpack::dtype_code::Int):
-        kind = 'i'; // signed integer
-        break;
-    case static_cast<uint8_t>(nb::dlpack::dtype_code::UInt):
-        kind = 'u'; // unsigned integer
-        break;
-    case static_cast<uint8_t>(nb::dlpack::dtype_code::Float):
-        kind = 'f'; // floating point
-        break;
-    case static_cast<uint8_t>(nb::dlpack::dtype_code::Bfloat):
-        kind = 'f'; // brain floating point (treat as float kind)
-        break;
-    case static_cast<uint8_t>(nb::dlpack::dtype_code::Complex):
-        kind = 'c'; // complex
-        break;
-    default:
-        kind = 'V'; // void/other
-        break;
-    }
+    std::string dtype_name = nb::cast<std::string>(object.attr("dtype").attr("name"));
+    nb::object metadata = object.attr("dtype").attr("metadata");
+
     tle::DataType dtype;
-    if (npDtype == nb::dtype<half>())
+    if (dtype_name == "float16")
     {
         dtype = tle::DataType::kFP16;
     }
-    else if (npDtype == nb::dtype<float>())
+    else if (dtype_name == "float32")
     {
         dtype = tle::DataType::kFP32;
     }
-    else if (npDtype == nb::dtype<int8_t>())
+    else if (dtype_name == "int8")
     {
         dtype = tle::DataType::kINT8;
     }
-    else if (npDtype == nb::dtype<int32_t>())
+    else if (dtype_name == "int32")
     {
         dtype = tle::DataType::kINT32;
     }
-    else if (npDtype == nb::dtype<int64_t>())
+    else if (dtype_name == "int64")
     {
         dtype = tle::DataType::kINT64;
     }
-    else if (kind == 'V' && array.itemsize() == 1)
+    else if (dtype_name == "void8" && !metadata.is_none() && nb::cast<std::string>(metadata["dtype"]) == "float8")
     {
         dtype = tle::DataType::kFP8;
     }
-    else if (kind == 'V' && array.itemsize() == 2)
+    else if (dtype_name == "void16" && !metadata.is_none() && nb::cast<std::string>(metadata["dtype"]) == "bfloat16")
     {
         dtype = tle::DataType::kBF16;
     }

@@ -112,16 +91,21 @@ tle::Tensor numpyToTensor(nb::ndarray<nb::numpy> const& array)
         TLLM_THROW("Unsupported numpy dtype.");
     }

-    // todo: improve the following code
+    nb::object array_interface = object.attr("__array_interface__");
+    nb::object shape_obj = array_interface["shape"];
     std::vector<int64_t> dims;
-    dims.reserve(array.ndim());
-    for (size_t i = 0; i < array.ndim(); ++i)
+    dims.reserve(nb::len(shape_obj));
+
+    for (size_t i = 0; i < nb::len(shape_obj); ++i)
     {
-        dims.push_back(static_cast<int64_t>(array.shape(i)));
+        dims.push_back(nb::cast<int64_t>(shape_obj[i]));
     }
-    tle::Shape shape(dims.data(), dims.size());

-    return tle::Tensor::of(dtype, const_cast<void*>(array.data()), shape);
+    nb::object data_obj = array_interface["data"];
+    uintptr_t addr = nb::cast<uintptr_t>(data_obj[0]);
+    void* data_ptr = reinterpret_cast<void*>(addr);
+    tle::Shape shape(dims.data(), dims.size());
+    return tle::Tensor::of(dtype, data_ptr, shape);
 }

 } // namespace

@@ -153,8 +137,8 @@ Executor::Executor(nb::bytes const& engineBuffer, std::string const& jsonConfigS
     for (auto const& [rawName, rawArray] : managedWeights.value())
     {
         std::string name = nb::cast<std::string>(rawName);
-        nb::ndarray<nb::numpy> array = nb::cast<nb::ndarray<nb::numpy>>(rawArray);
-        managedWeightsMap->emplace(name, numpyToTensor(array));
+        nb::object array_obj = nb::cast<nb::object>(rawArray);
+        managedWeightsMap->emplace(name, numpyToTensor(array_obj));
     }
 }
 mExecutor = std::make_unique<tle::Executor>(
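The rewritten `numpyToTensor` drops `nb::ndarray` and instead reads numpy's `__array_interface__` protocol, which also lets it consult `dtype.metadata` to identify custom void dtypes that carry their real type (float8, bfloat16) in metadata. A sketch of the protocol's relevant fields, as a standalone helper; the function name is illustrative and the object is assumed to wrap a numpy array:

#include <cstdint>
#include <cstdio>
#include <string>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>

namespace nb = nanobind;

void describeArray(nb::object const& obj)
{
    nb::object iface = obj.attr("__array_interface__");

    // "shape" is a tuple of ints describing the dimensions.
    nb::object shape = iface["shape"];
    size_t ndim = nb::len(shape);

    // "data" is a (pointer, read_only) pair; the first element is the
    // integer address of the first array element.
    nb::object data = iface["data"];
    auto addr = nb::cast<std::uintptr_t>(data[0]);
    bool readOnly = nb::cast<bool>(data[1]);

    // "typestr" encodes byte order, kind, and item size, e.g. "<f4".
    auto typestr = nb::cast<std::string>(iface["typestr"]);

    std::printf("ndim=%zu addr=%p read_only=%d typestr=%s\n", ndim,
        reinterpret_cast<void*>(addr), int(readOnly), typestr.c_str());
}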
