Skip to content

Commit 6a0920b

Browse files
committed
feat: decoded media via NIXL
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent cbe3353 commit 6a0920b

File tree

9 files changed

+165
-43
lines changed

9 files changed

+165
-43
lines changed

lib/llm/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ readme.workspace = true
1313
description = "Dynamo LLM Library"
1414

1515
[features]
16-
default = []
16+
default = ["block-manager"]
1717
# todo(ops): get this working in CI as a default.
1818
# default = ["block-manager", "testing-full"]
1919

lib/llm/src/block_manager/storage.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,30 @@ impl SystemStorage {
364364
}
365365
}
366366

367+
impl TryFrom<Vec<u8>> for SystemStorage {
368+
type Error = StorageError;
369+
370+
/// Create SystemStorage from an existing Vec<u8>
371+
/// Takes ownership of the Vec and uses its memory directly (zero-copy)
372+
fn try_from(mut vec: Vec<u8>) -> Result<Self, Self::Error> {
373+
let size = vec.len();
374+
let layout =
375+
Layout::array::<u8>(size).map_err(|e| StorageError::AllocationFailed(e.to_string()))?;
376+
let ptr = NonNull::new(vec.as_mut_ptr())
377+
.ok_or_else(|| StorageError::AllocationFailed("vec pointer is null".into()))?;
378+
379+
// prevents Vec from freeing the memory
380+
std::mem::forget(vec);
381+
382+
Ok(Self {
383+
ptr,
384+
layout,
385+
len: size,
386+
handles: RegistrationHandles::new(),
387+
})
388+
}
389+
}
390+
367391
impl Drop for SystemStorage {
368392
fn drop(&mut self) {
369393
self.handles.release();

lib/llm/src/preprocessor.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use std::{collections::HashMap, pin::Pin, sync::Arc};
2727
use tracing;
2828

2929
use crate::model_card::{ModelDeploymentCard, ModelInfo};
30-
use crate::preprocessor::media::MediaLoader;
30+
use crate::preprocessor::media::{MediaDecoder, MediaLoader};
3131
use crate::preprocessor::prompt::OAIChatLikeRequest;
3232
use crate::protocols::common::preprocessor::{
3333
MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder,
@@ -143,7 +143,7 @@ impl OpenAIPreprocessor {
143143

144144
// // Initialize runtime config from the ModelDeploymentCard
145145
let runtime_config = mdc.runtime_config.clone();
146-
let media_loader = None; // TODO: enable with decoder config from MDC
146+
let media_loader = Some(MediaLoader::new(MediaDecoder::default())?); // TODO: enable with decoder config from MDC
147147
Ok(Arc::new(Self {
148148
formatter,
149149
tokenizer,
@@ -327,14 +327,21 @@ impl OpenAIPreprocessor {
327327
// Execute all fetch tasks
328328
if !fetch_tasks.is_empty() {
329329
let loader = self.media_loader.as_ref().unwrap();
330-
let _results = futures::future::join_all(
330+
let results = futures::future::join_all(
331331
fetch_tasks
332332
.iter()
333333
.map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)),
334334
)
335335
.await;
336336

337-
// TODO: decode and pass NIXL descriptors to the media map
337+
for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) {
338+
// if one item fails, errors the whole request, other items will be cleaned up by Drop
339+
let rdma_descriptor = result?;
340+
media_map
341+
.entry(type_str)
342+
.or_default()
343+
.push(MultimodalData::Decoded(rdma_descriptor));
344+
}
338345
}
339346

340347
if !media_map.is_empty() {

lib/llm/src/preprocessor/media.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
mod common;
55
mod decoders;
66
mod loader;
7+
mod rdma;
78

89
pub use common::EncodedMediaData;
910
pub use decoders::{Decoder, ImageDecoder, MediaDecoder};
1011
pub use loader::MediaLoader;
12+
pub use rdma::{DecodedMediaData, RdmaMediaDataDescriptor, get_nixl_agent, get_nixl_metadata};

lib/llm/src/preprocessor/media/decoders.rs

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,11 @@
44
use anyhow::Result;
55

66
use super::common::EncodedMediaData;
7-
use ndarray::{ArrayBase, Dimension, OwnedRepr};
7+
use super::rdma::DecodedMediaData;
88
mod image;
99

1010
pub use image::ImageDecoder;
1111

12-
// Decoded media data (image RGB, video frames pixels, ...)
13-
#[derive(Debug)]
14-
pub struct DecodedMediaData {
15-
pub(crate) data: Vec<u8>,
16-
pub(crate) shape: Vec<usize>,
17-
pub(crate) dtype: String,
18-
}
19-
20-
// convert Array{N}<u8> to DecodedMediaData
21-
// TODO: Array1<f32> for audio
22-
impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
23-
fn from(array: ArrayBase<OwnedRepr<u8>, D>) -> Self {
24-
let shape = array.shape().to_vec();
25-
let (data, _) = array.into_raw_vec_and_offset();
26-
Self {
27-
data,
28-
shape,
29-
dtype: "uint8".to_string(),
30-
}
31-
}
32-
}
33-
3412
#[async_trait::async_trait]
3513
pub trait Decoder: Clone + Send + 'static {
3614
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;

lib/llm/src/preprocessor/media/decoders/image.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use image::GenericImageView;
66
use ndarray::Array3;
77

88
use super::super::common::EncodedMediaData;
9-
use super::super::decoders::DecodedMediaData;
9+
use super::super::rdma::DecodedMediaData;
1010
use super::Decoder;
1111

1212
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
@@ -38,7 +38,7 @@ impl Decoder for ImageDecoder {
3838
};
3939
let shape = (height as usize, width as usize, n_channels as usize);
4040
let array = Array3::from_shape_vec(shape, data)?;
41-
Ok(array.into())
41+
Ok(array.try_into()?)
4242
}
4343
}
4444

lib/llm/src/preprocessor/media/loader.rs

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@ use anyhow::Result;
66
use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
77

88
use super::common::EncodedMediaData;
9-
use super::decoders::{DecodedMediaData, Decoder, MediaDecoder};
9+
use super::decoders::{Decoder, MediaDecoder};
10+
use super::rdma::{RdmaMediaDataDescriptor, get_nixl_agent};
11+
use nixl_sys::Agent as NixlAgent;
1012

1113
// TODO: make this configurable
1214
const HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
1315

1416
pub struct MediaLoader {
1517
media_decoder: MediaDecoder,
1618
http_client: reqwest::Client,
17-
// TODO: NIXL agent
19+
nixl_agent: NixlAgent,
1820
}
1921

2022
impl MediaLoader {
@@ -23,19 +25,21 @@ impl MediaLoader {
2325
.user_agent(HTTP_USER_AGENT)
2426
.build()?;
2527

28+
let nixl_agent = get_nixl_agent()?;
29+
2630
Ok(Self {
2731
media_decoder,
2832
http_client,
33+
nixl_agent,
2934
})
3035
}
3136

3237
pub async fn fetch_and_decode_media_part(
3338
&self,
3439
oai_content_part: &ChatCompletionRequestUserMessageContentPart,
3540
// TODO: request-level options
36-
) -> Result<DecodedMediaData> {
37-
// fetch the media
38-
// TODO: decode and NIXL-register
41+
) -> Result<RdmaMediaDataDescriptor> {
42+
// fetch the media, decode and NIXL-register
3943
let decoded = match oai_content_part {
4044
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
4145
let url = &image_part.image_url.url;
@@ -53,13 +57,15 @@ impl MediaLoader {
5357
_ => anyhow::bail!("Unsupported media type"),
5458
};
5559

56-
Ok(decoded)
60+
let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
61+
Ok(rdma_descriptor)
5762
}
5863
}
5964

6065
#[cfg(test)]
6166
mod tests {
6267
use super::*;
68+
use crate::block_manager::storage::nixl::NixlRegisterableStorage;
6369
use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
6470

6571
// warning: non-airgap test
@@ -82,14 +88,23 @@ mod tests {
8288
result.err()
8389
);
8490

85-
let data = result.unwrap();
86-
assert_eq!(data.dtype, "uint8");
91+
let descriptor = result.unwrap();
92+
assert_eq!(descriptor.dtype, "uint8");
8793

8894
// Verify image dimensions: 1,999px × 1,125px (width × height)
8995
// Shape format is [height, width, channels]
90-
assert!(!data.shape.is_empty(), "Shape should not be empty");
91-
assert_eq!(data.shape[0], 1125, "Height should be 1125");
92-
assert_eq!(data.shape[1], 1999, "Width should be 1999");
93-
assert_eq!(data.shape[2], 4, "RGBA channels should be 4");
96+
assert!(!descriptor.shape.is_empty(), "Shape should not be empty");
97+
assert_eq!(descriptor.shape[0], 1125, "Height should be 1125");
98+
assert_eq!(descriptor.shape[1], 1999, "Width should be 1999");
99+
assert_eq!(descriptor.shape[2], 4, "RGBA channels should be 4");
100+
101+
assert!(
102+
descriptor.source_storage.is_some(),
103+
"Source storage should be present"
104+
);
105+
assert!(
106+
descriptor.source_storage.unwrap().is_nixl_registered(),
107+
"Source storage should be registered with NIXL"
108+
);
94109
}
95110
}
lib/llm/src/preprocessor/media/rdma.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use anyhow::Result;
5+
use base64::{Engine as _, engine::general_purpose};
6+
use ndarray::{ArrayBase, Dimension, OwnedRepr};
7+
use serde::{Deserialize, Serialize};
8+
use std::sync::Arc;
9+
10+
use crate::block_manager::storage::{
11+
StorageError, SystemStorage, nixl::NixlRegisterableStorage, nixl::NixlStorage,
12+
};
13+
use nixl_sys::{Agent as NixlAgent, MemType, RegDescList};
14+
15+
// Decoded media data (image RGB, video frames pixels, ...)
16+
#[derive(Debug)]
17+
pub struct DecodedMediaData {
18+
pub(crate) data: SystemStorage,
19+
pub(crate) shape: Vec<usize>,
20+
pub(crate) dtype: String,
21+
}
22+
23+
// Decoded media data NIXL descriptor (sent to the next step in the pipeline / NATS)
24+
#[derive(Serialize, Deserialize, Clone, Debug)]
25+
pub struct RdmaMediaDataDescriptor {
26+
// b64 agent metadata
27+
pub(crate) nixl_metadata: String,
28+
// tensor descriptor
29+
pub(crate) nixl_descriptor: NixlStorage,
30+
pub(crate) shape: Vec<usize>,
31+
pub(crate) dtype: String,
32+
// reference to the actual data, kept alive while the rdma descriptor is alive
33+
#[serde(skip, default)]
34+
#[allow(dead_code)]
35+
pub(crate) source_storage: Option<Arc<SystemStorage>>,
36+
}
37+
38+
impl DecodedMediaData {
39+
pub fn into_rdma_descriptor(self, nixl_agent: &NixlAgent) -> Result<RdmaMediaDataDescriptor> {
40+
// get NIXL metadata and descriptor
41+
let mut source_storage = self.data;
42+
source_storage.nixl_register(nixl_agent, None)?;
43+
let nixl_descriptor = unsafe { source_storage.as_nixl_descriptor() }
44+
.ok_or_else(|| anyhow::anyhow!("Cannot convert storage to NIXL descriptor"))?;
45+
46+
let nixl_metadata = get_nixl_metadata(nixl_agent, &source_storage)?;
47+
Ok(RdmaMediaDataDescriptor {
48+
nixl_metadata,
49+
nixl_descriptor,
50+
shape: self.shape,
51+
dtype: self.dtype,
52+
// do not drop / free the storage yet
53+
source_storage: Some(Arc::new(source_storage)),
54+
})
55+
}
56+
}
57+
58+
// convert Array{N}<u8> to DecodedMediaData
59+
// TODO: Array1<f32> for audio
60+
impl<D: Dimension> TryFrom<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
61+
type Error = StorageError;
62+
63+
fn try_from(array: ArrayBase<OwnedRepr<u8>, D>) -> Result<Self, Self::Error> {
64+
let shape = array.shape().to_vec();
65+
let (data, _) = array.into_raw_vec_and_offset();
66+
Ok(Self {
67+
data: SystemStorage::try_from(data)?,
68+
shape,
69+
dtype: "uint8".to_string(),
70+
})
71+
}
72+
}
73+
74+
// Get NIXL metadata for a descriptor
75+
// Avoids cross-request leak possibility and reduces metadata size
76+
// TODO: pre-allocate a fixed NIXL-registered RAM pool so metadata can be cached on the target?
77+
pub fn get_nixl_metadata(agent: &NixlAgent, storage: &SystemStorage) -> Result<String> {
78+
// WAR: Until https://github.com/ai-dynamo/nixl/pull/970 is merged, can't use get_local_partial_md
79+
let nixl_md = agent.get_local_md()?;
80+
// let mut reg_desc_list = RegDescList::new(MemType::Dram)?;
81+
// reg_desc_list.add_storage_desc(storage)?;
82+
// let nixl_partial_md = agent.get_local_partial_md(&reg_desc_list, None)?;
83+
84+
let b64_encoded = general_purpose::STANDARD.encode(&nixl_md);
85+
Ok(format!("b64:{}", b64_encoded))
86+
}
87+
88+
/// Create a NIXL agent with a UCX backend.
///
/// The agent is named `media-loader-<uuid>` using a freshly generated v4 UUID.
///
/// # Errors
///
/// Fails when the agent cannot be created, when the UCX plugin parameters
/// cannot be obtained, or when the UCX backend cannot be created.
pub fn get_nixl_agent() -> Result<NixlAgent> {
    const BACKEND: &str = "UCX";

    let agent_name = format!("media-loader-{}", uuid::Uuid::new_v4());
    let agent = NixlAgent::new(&agent_name)?;

    // The first element of the plugin-params tuple is not needed here.
    let (_, backend_params) = agent.get_plugin_params(BACKEND)?;
    agent.create_backend(BACKEND, &backend_params)?;

    Ok(agent)
}

lib/llm/src/protocols/common/preprocessor.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ use serde::{Deserialize, Serialize};
66

77
use super::{OutputOptions, SamplingOptions, StopConditions};
88
use crate::kv_router::RouterConfigOverride;
9+
use crate::preprocessor::media::RdmaMediaDataDescriptor;
910
use crate::protocols::TokenIdType;
1011

1112
#[derive(Serialize, Deserialize, Debug, Clone)]
1213
pub enum MultimodalData {
1314
Url(url::Url),
14-
// TODO: Decoded(DecodedMediaData),
15+
Decoded(RdmaMediaDataDescriptor),
1516
}
1617

1718
// multimodal map containing {mm_part_type: [data...]}
@@ -31,6 +32,7 @@ pub struct PreprocessedRequest {
3132
#[builder(default)]
3233
#[serde(default, skip_serializing_if = "Option::is_none")]
3334
pub multi_modal_data: Option<MultimodalDataMap>,
35+
3436
/// StopConditions are conditions that the inference engine will use to stop generation.
3537
pub stop_conditions: StopConditions,
3638

0 commit comments

Comments
 (0)