Skip to content

Commit ef274b4

Browse files
committed
feat: decoded media via NIXL
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent b888111 commit ef274b4

File tree

10 files changed

+240
-75
lines changed

10 files changed

+240
-75
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/llm/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ readme.workspace = true
1313
description = "Dynamo LLM Library"
1414

1515
[features]
16-
default = []
16+
default = ["block-manager"]
1717
# todo(ops): get this working in CI as a default.
1818
# default = ["block-manager", "testing-full"]
1919

@@ -142,7 +142,7 @@ json-five = { version = "0.3" }
142142
# media loading in the preprocessor
143143
reqwest = { workspace = true }
144144
base64 = { version = "0.22" }
145-
image = { version = "0.25" }
145+
image = { version = "0.25", features = ["default", "serde"] }
146146
tokio-rayon = {version = "2" }
147147
ndarray = { version = "0.16" }
148148

lib/llm/src/block_manager/storage.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,30 @@ impl SystemStorage {
364364
}
365365
}
366366

367+
impl TryFrom<Vec<u8>> for SystemStorage {
368+
type Error = StorageError;
369+
370+
/// Create SystemStorage from an existing Vec<u8>
371+
/// Takes ownership of the Vec and uses its memory directly (zero-copy)
372+
fn try_from(mut vec: Vec<u8>) -> Result<Self, Self::Error> {
373+
let size = vec.len();
374+
let layout =
375+
Layout::array::<u8>(size).map_err(|e| StorageError::AllocationFailed(e.to_string()))?;
376+
let ptr = NonNull::new(vec.as_mut_ptr())
377+
.ok_or_else(|| StorageError::AllocationFailed("vec pointer is null".into()))?;
378+
379+
// prevents Vec from freeing the memory
380+
std::mem::forget(vec);
381+
382+
Ok(Self {
383+
ptr,
384+
layout,
385+
len: size,
386+
handles: RegistrationHandles::new(),
387+
})
388+
}
389+
}
390+
367391
impl Drop for SystemStorage {
368392
fn drop(&mut self) {
369393
self.handles.release();

lib/llm/src/preprocessor.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,21 @@ impl OpenAIPreprocessor {
327327
// Execute all fetch tasks
328328
if !fetch_tasks.is_empty() {
329329
let loader = self.media_loader.as_ref().unwrap();
330-
let _results = futures::future::join_all(
330+
let results = futures::future::join_all(
331331
fetch_tasks
332332
.iter()
333333
.map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)),
334334
)
335335
.await;
336336

337-
// TODO: decode and pass NIXL descriptors to the media map
337+
for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) {
338+
// if one item fails, errors the whole request, other items will be cleaned up by Drop
339+
let rdma_descriptor = result?;
340+
media_map
341+
.entry(type_str)
342+
.or_default()
343+
.push(MultimodalData::Decoded(rdma_descriptor));
344+
}
338345
}
339346

340347
if !media_map.is_empty() {

lib/llm/src/preprocessor/media.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
mod common;
55
mod decoders;
66
mod loader;
7+
mod rdma;
78

89
pub use common::EncodedMediaData;
910
pub use decoders::{Decoder, ImageDecoder, MediaDecoder};
1011
pub use loader::MediaLoader;
12+
pub use rdma::{DecodedMediaData, RdmaMediaDataDescriptor, get_nixl_agent, get_nixl_metadata};

lib/llm/src/preprocessor/media/decoders.rs

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,52 +2,14 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
use anyhow::Result;
5+
use serde::{Deserialize, Serialize};
56

67
use super::common::EncodedMediaData;
7-
use ndarray::{ArrayBase, Dimension, OwnedRepr};
8-
mod image;
8+
use super::rdma::DecodedMediaData;
9+
pub mod image;
910

1011
pub use image::{ImageDecoder, ImageMetadata};
1112

12-
#[derive(Debug)]
13-
pub enum DecodedMediaMetadata {
14-
#[allow(dead_code)] // used in followup MR
15-
Image(ImageMetadata),
16-
}
17-
18-
#[derive(Debug, PartialEq, Eq)]
19-
pub enum DataType {
20-
UINT8,
21-
}
22-
23-
// Decoded media data (image RGB, video frames pixels, ...)
24-
#[derive(Debug)]
25-
pub struct DecodedMediaData {
26-
#[allow(dead_code)] // used in followup MR
27-
pub(crate) data: Vec<u8>,
28-
#[allow(dead_code)] // used in followup MR
29-
pub(crate) shape: Vec<usize>,
30-
#[allow(dead_code)] // used in followup MR
31-
pub(crate) dtype: DataType,
32-
#[allow(dead_code)] // used in followup MR
33-
pub(crate) metadata: Option<DecodedMediaMetadata>,
34-
}
35-
36-
// convert Array{N}<u8> to DecodedMediaData
37-
// TODO: Array1<f32> for audio
38-
impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
39-
fn from(array: ArrayBase<OwnedRepr<u8>, D>) -> Self {
40-
let shape = array.shape().to_vec();
41-
let (data, _) = array.into_raw_vec_and_offset();
42-
Self {
43-
data,
44-
shape,
45-
dtype: DataType::UINT8,
46-
metadata: None,
47-
}
48-
}
49-
}
50-
5113
#[async_trait::async_trait]
5214
pub trait Decoder: Clone + Send + 'static {
5315
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;
@@ -67,3 +29,9 @@ pub struct MediaDecoder {
6729
pub image_decoder: ImageDecoder,
6830
// TODO: video, audio decoders
6931
}
32+
33+
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
34+
pub enum DecodedMediaMetadata {
35+
#[allow(dead_code)] // used in followup MR
36+
Image(ImageMetadata),
37+
}

lib/llm/src/preprocessor/media/decoders/image.rs

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@ use std::io::Cursor;
66
use anyhow::Result;
77
use image::{ColorType, GenericImageView, ImageFormat, ImageReader};
88
use ndarray::Array3;
9+
use serde::{Deserialize, Serialize};
910

1011
use super::super::common::EncodedMediaData;
11-
use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata};
12-
use super::Decoder;
12+
use super::super::rdma::DecodedMediaData;
13+
use super::{DecodedMediaMetadata, Decoder};
1314

1415
const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB
1516

16-
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
17+
#[derive(Clone, Debug, Serialize, Deserialize)]
1718
#[serde(deny_unknown_fields)]
1819
pub struct ImageDecoder {
1920
#[serde(default)]
@@ -35,12 +36,12 @@ impl Default for ImageDecoder {
3536
}
3637
}
3738

38-
#[derive(Debug)]
39+
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
3940
pub enum ImageLayout {
4041
HWC,
4142
}
4243

43-
#[derive(Debug)]
44+
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
4445
pub struct ImageMetadata {
4546
#[allow(dead_code)] // used in followup MR
4647
pub(crate) format: Option<ImageFormat>,
@@ -77,8 +78,8 @@ impl Decoder for ImageDecoder {
7778
let (width, height) = img.dimensions();
7879
let shape = (height as usize, width as usize, n_channels as usize);
7980
let array = Array3::from_shape_vec(shape, data)?;
80-
let mut decoded: DecodedMediaData = array.into();
81-
decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
81+
let mut decoded: DecodedMediaData = array.try_into()?;
82+
decoded.tensor_info.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
8283
format,
8384
color_type,
8485
layout: ImageLayout::HWC,
@@ -89,7 +90,7 @@ impl Decoder for ImageDecoder {
8990

9091
#[cfg(test)]
9192
mod tests {
92-
use super::super::super::decoders::DataType;
93+
use super::super::super::rdma::DataType;
9394
use super::*;
9495
use image::{DynamicImage, ImageBuffer};
9596
use rstest::rstest;
@@ -155,10 +156,10 @@ mod tests {
155156

156157
let decoded = result.unwrap();
157158
assert_eq!(
158-
decoded.shape,
159+
decoded.tensor_info.shape,
159160
vec![height as usize, width as usize, expected_channels as usize]
160161
);
161-
assert_eq!(decoded.dtype, DataType::UINT8);
162+
assert_eq!(decoded.tensor_info.dtype, DataType::UINT8);
162163
}
163164

164165
#[rstest]
@@ -195,9 +196,12 @@ mod tests {
195196
format
196197
);
197198
let decoded = result.unwrap();
198-
assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
199199
assert_eq!(
200-
decoded.dtype,
200+
decoded.tensor_info.shape,
201+
vec![height as usize, width as usize, 3]
202+
);
203+
assert_eq!(
204+
decoded.tensor_info.dtype,
201205
DataType::UINT8,
202206
"dtype should be uint8 for case: {}",
203207
test_case
@@ -235,11 +239,15 @@ mod tests {
235239
);
236240

237241
let decoded = result.unwrap();
238-
assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions");
239-
assert_eq!(decoded.shape[0], 1, "Height should be 1");
240-
assert_eq!(decoded.shape[1], 1, "Width should be 1");
241242
assert_eq!(
242-
decoded.dtype,
243+
decoded.tensor_info.shape.len(),
244+
3,
245+
"Should have 3 dimensions"
246+
);
247+
assert_eq!(decoded.tensor_info.shape[0], 1, "Height should be 1");
248+
assert_eq!(decoded.tensor_info.shape[1], 1, "Width should be 1");
249+
assert_eq!(
250+
decoded.tensor_info.dtype,
243251
DataType::UINT8,
244252
"dtype should be uint8 for {} channels {:?}",
245253
input_channels,

lib/llm/src/preprocessor/media/loader.rs

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ use anyhow::Result;
99
use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
1010

1111
use super::common::EncodedMediaData;
12-
use super::decoders::{DecodedMediaData, Decoder, MediaDecoder};
12+
use super::decoders::{Decoder, MediaDecoder};
13+
use super::rdma::{RdmaMediaDataDescriptor, get_nixl_agent};
14+
use nixl_sys::Agent as NixlAgent;
1315

1416
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
1517
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
@@ -39,7 +41,7 @@ pub struct MediaLoader {
3941
media_decoder: MediaDecoder,
4042
http_client: reqwest::Client,
4143
media_fetcher: MediaFetcher,
42-
// TODO: NIXL agent
44+
nixl_agent: NixlAgent,
4345
}
4446

4547
impl MediaLoader {
@@ -53,10 +55,13 @@ impl MediaLoader {
5355

5456
let http_client = http_client_builder.build()?;
5557

58+
let nixl_agent = get_nixl_agent()?;
59+
5660
Ok(Self {
5761
media_decoder,
5862
http_client,
5963
media_fetcher,
64+
nixl_agent,
6065
})
6166
}
6267

@@ -90,9 +95,8 @@ impl MediaLoader {
9095
&self,
9196
oai_content_part: &ChatCompletionRequestUserMessageContentPart,
9297
// TODO: request-level options
93-
) -> Result<DecodedMediaData> {
94-
// fetch the media
95-
// TODO: decode and NIXL-register
98+
) -> Result<RdmaMediaDataDescriptor> {
99+
// fetch the media, decode and NIXL-register
96100
let decoded = match oai_content_part {
97101
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
98102
let url = &image_part.image_url.url;
@@ -112,14 +116,16 @@ impl MediaLoader {
112116
_ => anyhow::bail!("Unsupported media type"),
113117
};
114118

115-
Ok(decoded)
119+
let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
120+
Ok(rdma_descriptor)
116121
}
117122
}
118123

119124
#[cfg(test)]
120125
mod tests {
121-
use super::super::decoders::DataType;
126+
use super::super::rdma::DataType;
122127
use super::*;
128+
use crate::block_manager::storage::nixl::NixlRegisterableStorage;
123129
use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
124130

125131
#[tokio::test]
@@ -157,17 +163,52 @@ mod tests {
157163
result.err()
158164
);
159165

160-
let data = result.unwrap();
161-
assert_eq!(data.dtype, DataType::UINT8);
166+
let descriptor = result.unwrap();
167+
assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);
162168

163169
// Verify image dimensions: 1,999px × 1,125px (width × height)
164170
// Shape format is [height, width, channels]
165-
assert_eq!(data.shape.len(), 3);
166-
assert_eq!(data.shape[0], 1125, "Height should be 1125");
167-
assert_eq!(data.shape[1], 1999, "Width should be 1999");
168-
assert_eq!(data.shape[2], 4, "RGBA channels should be 4");
171+
assert_eq!(descriptor.tensor_info.shape.len(), 3);
172+
assert_eq!(
173+
descriptor.tensor_info.shape[0], 1125,
174+
"Height should be 1125"
175+
);
176+
assert_eq!(
177+
descriptor.tensor_info.shape[1], 1999,
178+
"Width should be 1999"
179+
);
180+
assert_eq!(
181+
descriptor.tensor_info.shape[2], 4,
182+
"RGBA channels should be 4"
183+
);
169184

170185
mock.assert_async().await;
186+
187+
assert!(
188+
!descriptor.tensor_info.shape.is_empty(),
189+
"Shape should not be empty"
190+
);
191+
assert_eq!(
192+
descriptor.tensor_info.shape[0], 1125,
193+
"Height should be 1125"
194+
);
195+
assert_eq!(
196+
descriptor.tensor_info.shape[1], 1999,
197+
"Width should be 1999"
198+
);
199+
assert_eq!(
200+
descriptor.tensor_info.shape[2], 4,
201+
"RGBA channels should be 4"
202+
);
203+
204+
assert!(
205+
descriptor.source_storage.is_some(),
206+
"Source storage should be present"
207+
);
208+
assert!(
209+
descriptor.source_storage.unwrap().is_nixl_registered(),
210+
"Source storage should be registered with NIXL"
211+
);
171212
}
172213

173214
#[test]

0 commit comments

Comments
 (0)