@@ -219,15 +219,42 @@ PYBIND11_MODULE(_llm_runner, m) {
219219
220220 // Bind Image class
221221 py::class_<Image>(m, " Image" )
222- .def (py::init<>())
223- .def_readwrite (" data" , &Image::data)
224- .def_readwrite (" width" , &Image::width)
225- .def_readwrite (" height" , &Image::height)
226- .def_readwrite (" channels" , &Image::channels)
222+ .def (
223+ py::init<std::vector<uint8_t >&&, int32_t , int32_t , int32_t >(),
224+ py::arg (" data" ),
225+ py::arg (" width" ),
226+ py::arg (" height" ),
227+ py::arg (" channels" ))
228+ .def (
229+ py::init<std::vector<float >&&, int32_t , int32_t , int32_t >(),
230+ py::arg (" data" ),
231+ py::arg (" width" ),
232+ py::arg (" height" ),
233+ py::arg (" channels" ))
234+ .def (" is_uint8" , &Image::is_uint8)
235+ .def (" is_float" , &Image::is_float)
236+ .def_property_readonly (" width" , &Image::width)
237+ .def_property_readonly (" height" , &Image::height)
238+ .def_property_readonly (" channels" , &Image::channels)
239+ .def_property_readonly (
240+ " uint8_data" ,
241+ static_cast <const std::vector<uint8_t >& (Image::*)() const &>(
242+ &Image::get_uint8_data))
243+ .def_property_readonly (
244+ " float_data" ,
245+ static_cast <const std::vector<float >& (Image::*)() const &>(
246+ &Image::get_float_data))
227247 .def (" __repr__" , [](const Image& img) {
228- return " <Image height=" + std::to_string (img.height ) +
229- " width=" + std::to_string (img.width ) +
230- " channels=" + std::to_string (img.channels ) + " >" ;
248+ std::string dtype = " unknown" ;
249+ if (img.is_uint8 ()) {
250+ dtype = " uint8" ;
251+ } else if (img.is_float ()) {
252+ dtype = " float32" ;
253+ }
254+ return " <Image height=" + std::to_string (img.height ()) +
255+ " width=" + std::to_string (img.width ()) +
256+ " channels=" + std::to_string (img.channels ()) + " dtype=" + dtype +
257+ " >" ;
231258 });
232259
233260 // Bind MultimodalInput
@@ -281,7 +308,6 @@ PYBIND11_MODULE(_llm_runner, m) {
281308 image_tensor = image_tensor.squeeze (0 );
282309 }
283310
284-
285311 if (image_tensor.dim () != 3 ) {
286312 throw std::runtime_error (
287313 " Image tensor must be 3-dimensional (H, W, C) or 4-dimensional (1, H, W, C)" );
@@ -322,12 +348,11 @@ PYBIND11_MODULE(_llm_runner, m) {
322348 uint8_t * data = image_tensor.data_ptr <uint8_t >();
323349 std::vector<uint8_t > image_data (data, data + image_tensor.numel ());
324350
325- Image image;
326- image.data = std::move (image_data);
327- image.width = static_cast <int32_t >(width);
328- image.height = static_cast <int32_t >(height);
329- image.channels = static_cast <int32_t >(channels);
330- return MultimodalInput (std::move (image));
351+ return MultimodalInput (Image (
352+ std::move (image_data),
353+ static_cast <int32_t >(width),
354+ static_cast <int32_t >(height),
355+ static_cast <int32_t >(channels)));
331356 },
332357 " Create an image input from a torch tensor (H, W, C), (1, H, W, C), (C, H, W), or (1, C, H, W)" ,
333358 py::arg (" image_tensor" ));
0 commit comments