Skip to content

Commit 07fb85a

Browse files
committed
feat: decoded media via NIXL
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent ae4b08a commit 07fb85a

File tree

11 files changed

+209
-76
lines changed

11 files changed

+209
-76
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "distributed"]
4747
dynamo-runtime = { path = "lib/runtime", version = "0.6.1" }
4848
dynamo-llm = { path = "lib/llm", version = "0.6.1" }
4949
dynamo-config = { path = "lib/config", version = "0.6.1" }
50+
dynamo-memory = { path = "lib/memory", version = "0.6.1" }
5051
dynamo-tokens = { path = "lib/tokens", version = "0.6.1" }
5152
dynamo-async-openai = { path = "lib/async-openai", version = "0.6.1", features = [
5253
"byot",

lib/llm/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ harness = false
3333
name = "transfer_context_v2"
3434
harness = false
3535
required-features = ["block-manager", "testing-cuda"]
36+
3637
[dependencies]
3738
# repo
3839
dynamo-runtime = { workspace = true }
@@ -41,6 +42,7 @@ dynamo-runtime = { workspace = true }
4142
aho-corasick = "1.1"
4243
anyhow = { workspace = true }
4344
dynamo-async-openai = { workspace = true }
45+
dynamo-memory = { workspace = true }
4446
dynamo-parsers = { workspace = true }
4547
async-stream = { workspace = true }
4648
async-trait = { workspace = true }
@@ -142,7 +144,7 @@ json-five = { version = "0.3" }
142144
# media loading in the preprocessor
143145
reqwest = { workspace = true }
144146
base64 = { version = "0.22" }
145-
image = { version = "0.25" }
147+
image = { version = "0.25", features = ["default", "serde"] }
146148
tokio-rayon = {version = "2" }
147149
ndarray = { version = "0.16" }
148150

lib/llm/src/preprocessor.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,21 @@ impl OpenAIPreprocessor {
327327
// Execute all fetch tasks
328328
if !fetch_tasks.is_empty() {
329329
let loader = self.media_loader.as_ref().unwrap();
330-
let _results = futures::future::join_all(
330+
let results = futures::future::join_all(
331331
fetch_tasks
332332
.iter()
333333
.map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)),
334334
)
335335
.await;
336336

337-
// TODO: decode and pass NIXL descriptors to the media map
337+
for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) {
338+
// if one item fails, errors the whole request, other items will be cleaned up by Drop
339+
let rdma_descriptor = result?;
340+
media_map
341+
.entry(type_str)
342+
.or_default()
343+
.push(MultimodalData::Decoded(rdma_descriptor));
344+
}
338345
}
339346

340347
if !media_map.is_empty() {

lib/llm/src/preprocessor/media.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
mod common;
55
mod decoders;
66
mod loader;
7+
mod rdma;
78

89
pub use common::EncodedMediaData;
910
pub use decoders::{Decoder, ImageDecoder, MediaDecoder};
1011
pub use loader::MediaLoader;
12+
pub use rdma::{DecodedMediaData, RdmaMediaDataDescriptor, get_nixl_agent, get_nixl_metadata};

lib/llm/src/preprocessor/media/decoders.rs

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,52 +2,14 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
use anyhow::Result;
5+
use serde::{Deserialize, Serialize};
56

67
use super::common::EncodedMediaData;
7-
use ndarray::{ArrayBase, Dimension, OwnedRepr};
8-
mod image;
8+
use super::rdma::DecodedMediaData;
9+
pub mod image;
910

1011
pub use image::{ImageDecoder, ImageMetadata};
1112

12-
#[derive(Debug)]
13-
pub enum DecodedMediaMetadata {
14-
#[allow(dead_code)] // used in followup MR
15-
Image(ImageMetadata),
16-
}
17-
18-
#[derive(Debug, PartialEq, Eq)]
19-
pub enum DataType {
20-
UINT8,
21-
}
22-
23-
// Decoded media data (image RGB, video frames pixels, ...)
24-
#[derive(Debug)]
25-
pub struct DecodedMediaData {
26-
#[allow(dead_code)] // used in followup MR
27-
pub(crate) data: Vec<u8>,
28-
#[allow(dead_code)] // used in followup MR
29-
pub(crate) shape: Vec<usize>,
30-
#[allow(dead_code)] // used in followup MR
31-
pub(crate) dtype: DataType,
32-
#[allow(dead_code)] // used in followup MR
33-
pub(crate) metadata: Option<DecodedMediaMetadata>,
34-
}
35-
36-
// convert Array{N}<u8> to DecodedMediaData
37-
// TODO: Array1<f32> for audio
38-
impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
39-
fn from(array: ArrayBase<OwnedRepr<u8>, D>) -> Self {
40-
let shape = array.shape().to_vec();
41-
let (data, _) = array.into_raw_vec_and_offset();
42-
Self {
43-
data,
44-
shape,
45-
dtype: DataType::UINT8,
46-
metadata: None,
47-
}
48-
}
49-
}
50-
5113
#[async_trait::async_trait]
5214
pub trait Decoder: Clone + Send + 'static {
5315
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;
@@ -67,3 +29,9 @@ pub struct MediaDecoder {
6729
pub image_decoder: ImageDecoder,
6830
// TODO: video, audio decoders
6931
}
32+
33+
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
34+
pub enum DecodedMediaMetadata {
35+
#[allow(dead_code)] // used in followup MR
36+
Image(ImageMetadata),
37+
}

lib/llm/src/preprocessor/media/decoders/image.rs

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@ use std::io::Cursor;
66
use anyhow::Result;
77
use image::{ColorType, GenericImageView, ImageFormat, ImageReader};
88
use ndarray::Array3;
9+
use serde::{Deserialize, Serialize};
910

1011
use super::super::common::EncodedMediaData;
11-
use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata};
12-
use super::Decoder;
12+
use super::super::rdma::DecodedMediaData;
13+
use super::{DecodedMediaMetadata, Decoder};
1314

1415
const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB
1516

16-
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
17+
#[derive(Clone, Debug, Serialize, Deserialize)]
1718
#[serde(deny_unknown_fields)]
1819
pub struct ImageDecoder {
1920
#[serde(default)]
@@ -36,12 +37,12 @@ impl Default for ImageDecoder {
3637
}
3738

3839
#[allow(clippy::upper_case_acronyms)]
39-
#[derive(Debug)]
40+
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
4041
pub enum ImageLayout {
4142
HWC,
4243
}
4344

44-
#[derive(Debug)]
45+
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
4546
pub struct ImageMetadata {
4647
#[allow(dead_code)] // used in followup MR
4748
pub(crate) format: Option<ImageFormat>,
@@ -78,8 +79,8 @@ impl Decoder for ImageDecoder {
7879
let (width, height) = img.dimensions();
7980
let shape = (height as usize, width as usize, n_channels as usize);
8081
let array = Array3::from_shape_vec(shape, data)?;
81-
let mut decoded: DecodedMediaData = array.into();
82-
decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
82+
let mut decoded: DecodedMediaData = array.try_into()?;
83+
decoded.tensor_info.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
8384
format,
8485
color_type,
8586
layout: ImageLayout::HWC,
@@ -90,7 +91,7 @@ impl Decoder for ImageDecoder {
9091

9192
#[cfg(test)]
9293
mod tests {
93-
use super::super::super::decoders::DataType;
94+
use super::super::super::rdma::DataType;
9495
use super::*;
9596
use image::{DynamicImage, ImageBuffer};
9697
use rstest::rstest;
@@ -156,10 +157,10 @@ mod tests {
156157

157158
let decoded = result.unwrap();
158159
assert_eq!(
159-
decoded.shape,
160+
decoded.tensor_info.shape,
160161
vec![height as usize, width as usize, expected_channels as usize]
161162
);
162-
assert_eq!(decoded.dtype, DataType::UINT8);
163+
assert_eq!(decoded.tensor_info.dtype, DataType::UINT8);
163164
}
164165

165166
#[rstest]
@@ -196,9 +197,12 @@ mod tests {
196197
format
197198
);
198199
let decoded = result.unwrap();
199-
assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
200200
assert_eq!(
201-
decoded.dtype,
201+
decoded.tensor_info.shape,
202+
vec![height as usize, width as usize, 3]
203+
);
204+
assert_eq!(
205+
decoded.tensor_info.dtype,
202206
DataType::UINT8,
203207
"dtype should be uint8 for case: {}",
204208
test_case
@@ -236,11 +240,15 @@ mod tests {
236240
);
237241

238242
let decoded = result.unwrap();
239-
assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions");
240-
assert_eq!(decoded.shape[0], 1, "Height should be 1");
241-
assert_eq!(decoded.shape[1], 1, "Width should be 1");
242243
assert_eq!(
243-
decoded.dtype,
244+
decoded.tensor_info.shape.len(),
245+
3,
246+
"Should have 3 dimensions"
247+
);
248+
assert_eq!(decoded.tensor_info.shape[0], 1, "Height should be 1");
249+
assert_eq!(decoded.tensor_info.shape[1], 1, "Width should be 1");
250+
assert_eq!(
251+
decoded.tensor_info.dtype,
244252
DataType::UINT8,
245253
"dtype should be uint8 for {} channels {:?}",
246254
input_channels,

lib/llm/src/preprocessor/media/loader.rs

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ use anyhow::Result;
99
use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
1010

1111
use super::common::EncodedMediaData;
12-
use super::decoders::{DecodedMediaData, Decoder, MediaDecoder};
12+
use super::decoders::{Decoder, MediaDecoder};
13+
use super::rdma::{RdmaMediaDataDescriptor, get_nixl_agent};
14+
use dynamo_memory::nixl::NixlAgent;
1315

1416
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
1517
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
@@ -39,7 +41,7 @@ pub struct MediaLoader {
3941
media_decoder: MediaDecoder,
4042
http_client: reqwest::Client,
4143
media_fetcher: MediaFetcher,
42-
// TODO: NIXL agent
44+
nixl_agent: NixlAgent,
4345
}
4446

4547
impl MediaLoader {
@@ -53,10 +55,13 @@ impl MediaLoader {
5355

5456
let http_client = http_client_builder.build()?;
5557

58+
let nixl_agent = get_nixl_agent()?;
59+
5660
Ok(Self {
5761
media_decoder,
5862
http_client,
5963
media_fetcher,
64+
nixl_agent,
6065
})
6166
}
6267

@@ -90,9 +95,8 @@ impl MediaLoader {
9095
&self,
9196
oai_content_part: &ChatCompletionRequestUserMessageContentPart,
9297
// TODO: request-level options
93-
) -> Result<DecodedMediaData> {
94-
// fetch the media
95-
// TODO: decode and NIXL-register
98+
) -> Result<RdmaMediaDataDescriptor> {
99+
// fetch the media, decode and NIXL-register
96100
let decoded = match oai_content_part {
97101
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
98102
let url = &image_part.image_url.url;
@@ -112,13 +116,14 @@ impl MediaLoader {
112116
_ => anyhow::bail!("Unsupported media type"),
113117
};
114118

115-
Ok(decoded)
119+
let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
120+
Ok(rdma_descriptor)
116121
}
117122
}
118123

119124
#[cfg(test)]
120125
mod tests {
121-
use super::super::decoders::DataType;
126+
use super::super::rdma::DataType;
122127
use super::*;
123128
use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
124129

@@ -157,17 +162,33 @@ mod tests {
157162
result.err()
158163
);
159164

160-
let data = result.unwrap();
161-
assert_eq!(data.dtype, DataType::UINT8);
165+
let descriptor = result.unwrap();
166+
assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);
162167

163168
// Verify image dimensions: 1,999px × 1,125px (width × height)
164169
// Shape format is [height, width, channels]
165-
assert_eq!(data.shape.len(), 3);
166-
assert_eq!(data.shape[0], 1125, "Height should be 1125");
167-
assert_eq!(data.shape[1], 1999, "Width should be 1999");
168-
assert_eq!(data.shape[2], 4, "RGBA channels should be 4");
170+
assert_eq!(descriptor.tensor_info.shape.len(), 3);
171+
assert_eq!(
172+
descriptor.tensor_info.shape[0], 1125,
173+
"Height should be 1125"
174+
);
175+
assert_eq!(
176+
descriptor.tensor_info.shape[1], 1999,
177+
"Width should be 1999"
178+
);
179+
assert_eq!(
180+
descriptor.tensor_info.shape[2], 4,
181+
"RGBA channels should be 4"
182+
);
169183

170-
mock.assert_async().await;
184+
assert!(
185+
descriptor.source_storage.is_some(),
186+
"Source storage should be present"
187+
);
188+
assert!(
189+
descriptor.source_storage.unwrap().is_registered(),
190+
"Source storage should be registered with NIXL"
191+
);
171192
}
172193

173194
#[test]

0 commit comments

Comments
 (0)