diff --git a/Cargo.lock b/Cargo.lock index 3d8423ccd2..e4a59896db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2234,6 +2234,7 @@ dependencies = [ "derive_builder", "dialoguer", "dynamo-async-openai", + "dynamo-memory", "dynamo-parsers", "dynamo-runtime", "either", @@ -3923,6 +3924,7 @@ dependencies = [ "ravif", "rayon", "rgb", + "serde", "tiff", "zune-core", "zune-jpeg", @@ -5341,9 +5343,9 @@ dependencies = [ [[package]] name = "nixl-sys" version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a73b92494c94b2ff2d004cd9274d966863089e867dc9cd98bc640aefe7622036" +source = "git+https://github.com/ai-dynamo/nixl?rev=ae3f8af#ae3f8af9508a1e1f8aeb687ae3ae66644d3ba5e8" dependencies = [ + "anyhow", "bindgen 0.71.1", "cc", "libc", diff --git a/Cargo.toml b/Cargo.toml index 3b813ef3e9..041bff3bbd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "distributed"] dynamo-runtime = { path = "lib/runtime", version = "0.6.1" } dynamo-llm = { path = "lib/llm", version = "0.6.1" } dynamo-config = { path = "lib/config", version = "0.6.1" } +dynamo-memory = { path = "lib/memory", version = "0.6.1" } dynamo-tokens = { path = "lib/tokens", version = "0.6.1" } dynamo-async-openai = { path = "lib/async-openai", version = "0.6.1", features = [ "byot", diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index 3878f6e2a6..0c7083b237 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -358,7 +358,7 @@ version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff" dependencies = [ - "bindgen", + "bindgen 0.69.5", "cc", "cmake", "dunce", @@ -558,6 +558,26 @@ dependencies = [ "which", ] +[[package]] +name = "bindgen" +version = "0.71.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" +dependencies = [ + "bitflags 2.9.3", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 2.1.1", + "shlex", + "syn 2.0.106", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -1443,6 +1463,13 @@ dependencies = [ "uuid", ] +[[package]] +name = "dynamo-config" +version = "0.6.1" +dependencies = [ + "anyhow", +] + [[package]] name = "dynamo-llm" version = "0.6.1" @@ -1471,6 +1498,7 @@ dependencies = [ "derive_builder", "dialoguer", "dynamo-async-openai", + "dynamo-memory", "dynamo-parsers", "dynamo-runtime", "either", @@ -1527,6 +1555,22 @@ dependencies = [ "zeromq", ] +[[package]] +name = "dynamo-memory" +version = "0.6.1" +dependencies = [ + "anyhow", + "cudarc", + "dynamo-config", + "libc", + "nix 0.30.1", + "nixl-sys", + "offset-allocator", + "serde", + "thiserror 2.0.16", + "tracing", +] + [[package]] name = "dynamo-parsers" version = "0.6.1" @@ -1610,7 +1654,7 @@ dependencies = [ "local-ip-address", "log", "nid", - "nix", + "nix 0.29.0", "nuid", "once_cell", "opentelemetry", @@ -2790,6 +2834,7 @@ dependencies = [ "ravif", "rayon", "rgb", + "serde", "tiff", "zune-core", "zune-jpeg", @@ -3627,6 +3672,34 @@ dependencies = [ "libc", ] +[[package]] +name = "nix" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +dependencies = [ + "bitflags 2.9.3", + "cfg-if 1.0.3", + "cfg_aliases", + "libc", +] + +[[package]] +name = "nixl-sys" +version = "0.7.0" +source = "git+https://github.com/ai-dynamo/nixl?rev=ae3f8af#ae3f8af9508a1e1f8aeb687ae3ae66644d3ba5e8" +dependencies = [ + "anyhow", + "bindgen 0.71.1", + "cc", + "libc", + "os_info", + "pkg-config", + "serde", + "thiserror 2.0.16", + "tracing", +] + [[package]] name = "nkeys" version = "0.4.5" @@ -4003,6 +4076,18 @@ dependencies = [ "hashbrown 0.14.5", ] 
+[[package]] +name = "os_info" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0e1ac5fde8d43c34139135df8ea9ee9465394b2d8d20f032d38998f64afffc3" +dependencies = [ + "log", + "plist", + "serde", + "windows-sys 0.52.0", +] + [[package]] name = "overload" version = "0.1.1" @@ -4228,6 +4313,19 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plist" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "740ebea15c5d1428f910cd1a5f52cebf8d25006245ed8ade92702f4943d91e07" +dependencies = [ + "base64 0.22.1", + "indexmap 2.11.0", + "quick-xml", + "serde", + "time", +] + [[package]] name = "png" version = "0.17.16" @@ -4614,6 +4712,15 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" +[[package]] +name = "quick-xml" +version = "0.38.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42a232e7487fc2ef313d96dde7948e7a3c05101870d8985e4fd8d26aedd27b89" +dependencies = [ + "memchr", +] + [[package]] name = "quinn" version = "0.11.9" diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index ecc4a20d82..84a3cc0f96 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -33,6 +33,7 @@ harness = false name = "transfer_context_v2" harness = false required-features = ["block-manager", "testing-cuda"] + [dependencies] # repo dynamo-runtime = { workspace = true } @@ -41,6 +42,7 @@ dynamo-runtime = { workspace = true } aho-corasick = "1.1" anyhow = { workspace = true } dynamo-async-openai = { workspace = true } +dynamo-memory = { workspace = true } dynamo-parsers = { workspace = true } async-stream = { workspace = true } async-trait = { workspace = true } @@ -95,7 +97,7 @@ dialoguer = { version = "0.11", 
default-features = false, features = [ # block_manager aligned-vec = { version = "0.6.4", optional = true } -nixl-sys = { version = "=0.7.0", optional = true } +nixl-sys = { git = "https://github.com/ai-dynamo/nixl", rev = "ae3f8af", optional = true } cudarc = { workspace = true, optional = true } nix = { version = "0.26", optional = true } @@ -142,7 +144,7 @@ json-five = { version = "0.3" } # media loading in the preprocessor reqwest = { workspace = true } base64 = { version = "0.22" } -image = { version = "0.25" } +image = { version = "0.25", features = ["default", "serde"] } tokio-rayon = {version = "2" } ndarray = { version = "0.16" } diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 9fda891fec..13f430482d 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -327,14 +327,21 @@ impl OpenAIPreprocessor { // Execute all fetch tasks if !fetch_tasks.is_empty() { let loader = self.media_loader.as_ref().unwrap(); - let _results = futures::future::join_all( + let results = futures::future::join_all( fetch_tasks .iter() .map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)), ) .await; - // TODO: decode and pass NIXL descriptors to the media map + for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) { + // if one item fails, errors the whole request, other items will be cleaned up by Drop + let rdma_descriptor = result?; + media_map + .entry(type_str) + .or_default() + .push(MultimodalData::Decoded(rdma_descriptor)); + } } if !media_map.is_empty() { diff --git a/lib/llm/src/preprocessor/media.rs b/lib/llm/src/preprocessor/media.rs index 5104af8e21..ce0a99a155 100644 --- a/lib/llm/src/preprocessor/media.rs +++ b/lib/llm/src/preprocessor/media.rs @@ -4,7 +4,9 @@ mod common; mod decoders; mod loader; +mod rdma; pub use common::EncodedMediaData; pub use decoders::{Decoder, ImageDecoder, MediaDecoder}; pub use loader::MediaLoader; +pub use rdma::{DecodedMediaData, 
RdmaMediaDataDescriptor, get_nixl_agent, get_nixl_metadata}; diff --git a/lib/llm/src/preprocessor/media/decoders.rs b/lib/llm/src/preprocessor/media/decoders.rs index aa546915ec..230095975e 100644 --- a/lib/llm/src/preprocessor/media/decoders.rs +++ b/lib/llm/src/preprocessor/media/decoders.rs @@ -2,52 +2,14 @@ // SPDX-License-Identifier: Apache-2.0 use anyhow::Result; +use serde::{Deserialize, Serialize}; use super::common::EncodedMediaData; -use ndarray::{ArrayBase, Dimension, OwnedRepr}; -mod image; +use super::rdma::DecodedMediaData; +pub mod image; pub use image::{ImageDecoder, ImageMetadata}; -#[derive(Debug)] -pub enum DecodedMediaMetadata { - #[allow(dead_code)] // used in followup MR - Image(ImageMetadata), -} - -#[derive(Debug, PartialEq, Eq)] -pub enum DataType { - UINT8, -} - -// Decoded media data (image RGB, video frames pixels, ...) -#[derive(Debug)] -pub struct DecodedMediaData { - #[allow(dead_code)] // used in followup MR - pub(crate) data: Vec, - #[allow(dead_code)] // used in followup MR - pub(crate) shape: Vec, - #[allow(dead_code)] // used in followup MR - pub(crate) dtype: DataType, - #[allow(dead_code)] // used in followup MR - pub(crate) metadata: Option, -} - -// convert Array{N} to DecodedMediaData -// TODO: Array1 for audio -impl From, D>> for DecodedMediaData { - fn from(array: ArrayBase, D>) -> Self { - let shape = array.shape().to_vec(); - let (data, _) = array.into_raw_vec_and_offset(); - Self { - data, - shape, - dtype: DataType::UINT8, - metadata: None, - } - } -} - #[async_trait::async_trait] pub trait Decoder: Clone + Send + 'static { fn decode(&self, data: EncodedMediaData) -> Result; @@ -67,3 +29,9 @@ pub struct MediaDecoder { pub image_decoder: ImageDecoder, // TODO: video, audio decoders } + +#[derive(Serialize, Deserialize, Clone, Copy, Debug)] +pub enum DecodedMediaMetadata { + #[allow(dead_code)] // used in followup MR + Image(ImageMetadata), +} diff --git a/lib/llm/src/preprocessor/media/decoders/image.rs 
b/lib/llm/src/preprocessor/media/decoders/image.rs index e6c857d33b..fa915f7b34 100644 --- a/lib/llm/src/preprocessor/media/decoders/image.rs +++ b/lib/llm/src/preprocessor/media/decoders/image.rs @@ -6,14 +6,15 @@ use std::io::Cursor; use anyhow::Result; use image::{ColorType, GenericImageView, ImageFormat, ImageReader}; use ndarray::Array3; +use serde::{Deserialize, Serialize}; use super::super::common::EncodedMediaData; -use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata}; -use super::Decoder; +use super::super::rdma::DecodedMediaData; +use super::{DecodedMediaMetadata, Decoder}; const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct ImageDecoder { #[serde(default)] @@ -36,18 +37,16 @@ impl Default for ImageDecoder { } #[allow(clippy::upper_case_acronyms)] -#[derive(Debug)] +#[derive(Serialize, Deserialize, Clone, Copy, Debug)] pub enum ImageLayout { HWC, } -#[derive(Debug)] +#[derive(Serialize, Deserialize, Clone, Copy, Debug)] pub struct ImageMetadata { #[allow(dead_code)] // used in followup MR pub(crate) format: Option, #[allow(dead_code)] // used in followup MR - pub(crate) color_type: ColorType, - #[allow(dead_code)] // used in followup MR pub(crate) layout: ImageLayout, } @@ -78,10 +77,9 @@ impl Decoder for ImageDecoder { let (width, height) = img.dimensions(); let shape = (height as usize, width as usize, n_channels as usize); let array = Array3::from_shape_vec(shape, data)?; - let mut decoded: DecodedMediaData = array.into(); - decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata { + let mut decoded: DecodedMediaData = array.try_into()?; + decoded.tensor_info.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata { format, - color_type, layout: ImageLayout::HWC, })); Ok(decoded) @@ -90,7 +88,7 @@ impl Decoder for ImageDecoder { #[cfg(test)] mod tests { - use 
super::super::super::decoders::DataType; + use super::super::super::rdma::DataType; use super::*; use image::{DynamicImage, ImageBuffer}; use rstest::rstest; @@ -156,10 +154,10 @@ mod tests { let decoded = result.unwrap(); assert_eq!( - decoded.shape, + decoded.tensor_info.shape, vec![height as usize, width as usize, expected_channels as usize] ); - assert_eq!(decoded.dtype, DataType::UINT8); + assert_eq!(decoded.tensor_info.dtype, DataType::UINT8); } #[rstest] @@ -196,9 +194,12 @@ mod tests { format ); let decoded = result.unwrap(); - assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]); assert_eq!( - decoded.dtype, + decoded.tensor_info.shape, + vec![height as usize, width as usize, 3] + ); + assert_eq!( + decoded.tensor_info.dtype, DataType::UINT8, "dtype should be uint8 for case: {}", test_case @@ -236,11 +237,15 @@ mod tests { ); let decoded = result.unwrap(); - assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions"); - assert_eq!(decoded.shape[0], 1, "Height should be 1"); - assert_eq!(decoded.shape[1], 1, "Width should be 1"); assert_eq!( - decoded.dtype, + decoded.tensor_info.shape.len(), + 3, + "Should have 3 dimensions" + ); + assert_eq!(decoded.tensor_info.shape[0], 1, "Height should be 1"); + assert_eq!(decoded.tensor_info.shape[1], 1, "Width should be 1"); + assert_eq!( + decoded.tensor_info.dtype, DataType::UINT8, "dtype should be uint8 for {} channels {:?}", input_channels, diff --git a/lib/llm/src/preprocessor/media/loader.rs b/lib/llm/src/preprocessor/media/loader.rs index 91fc65d9bc..7ddb7b28c5 100644 --- a/lib/llm/src/preprocessor/media/loader.rs +++ b/lib/llm/src/preprocessor/media/loader.rs @@ -9,7 +9,9 @@ use anyhow::Result; use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart; use super::common::EncodedMediaData; -use super::decoders::{DecodedMediaData, Decoder, MediaDecoder}; +use super::decoders::{Decoder, MediaDecoder}; +use super::rdma::{RdmaMediaDataDescriptor, get_nixl_agent}; +use 
dynamo_memory::nixl::NixlAgent; const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo"; const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30); @@ -39,7 +41,7 @@ pub struct MediaLoader { media_decoder: MediaDecoder, http_client: reqwest::Client, media_fetcher: MediaFetcher, - // TODO: NIXL agent + nixl_agent: Option, } impl MediaLoader { @@ -53,10 +55,21 @@ impl MediaLoader { let http_client = http_client_builder.build()?; + let nixl_agent = match get_nixl_agent() { + Ok(agent) => Some(agent), + Err(e) => { + tracing::warn!( + "Error when creating NIXL agent (will not be able to register media data): {e}" + ); + None + } + }; + Ok(Self { media_decoder, http_client, media_fetcher, + nixl_agent, }) } @@ -90,9 +103,12 @@ impl MediaLoader { &self, oai_content_part: &ChatCompletionRequestUserMessageContentPart, // TODO: request-level options - ) -> Result { - // fetch the media - // TODO: decode and NIXL-register + ) -> Result { + if self.nixl_agent.is_none() { + anyhow::bail!("NIXL agent is not available, cannot decode and register media data"); + } + + // fetch the media, decode and NIXL-register let decoded = match oai_content_part { ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => { let url = &image_part.image_url.url; @@ -112,13 +128,14 @@ impl MediaLoader { _ => anyhow::bail!("Unsupported media type"), }; - Ok(decoded) + let rdma_descriptor = decoded.into_rdma_descriptor(self.nixl_agent.as_ref().unwrap())?; + Ok(rdma_descriptor) } } #[cfg(test)] mod tests { - use super::super::decoders::DataType; + use super::super::rdma::DataType; use super::*; use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl}; @@ -143,7 +160,7 @@ mod tests { ..Default::default() }; - let loader = MediaLoader::new(media_decoder, fetcher).unwrap(); + let loader: MediaLoader = MediaLoader::new(media_decoder, fetcher).unwrap(); let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url())); let 
content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl( @@ -151,23 +168,40 @@ mod tests { ); let result = loader.fetch_and_decode_media_part(&content_part).await; - assert!( - result.is_ok(), - "Failed to fetch and decode image: {:?}", - result.err() - ); - - let data = result.unwrap(); - assert_eq!(data.dtype, DataType::UINT8); + let descriptor = match result { + Ok(descriptor) => descriptor, + Err(e) if e.to_string().contains("NIXL agent is not available") => { + eprintln!("Skipping test: NIXL agent not available"); + return; + } + Err(e) => panic!("Failed to fetch and decode image: {}", e), + }; + assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8); // Verify image dimensions: 1,999px × 1,125px (width × height) // Shape format is [height, width, channels] - assert_eq!(data.shape.len(), 3); - assert_eq!(data.shape[0], 1125, "Height should be 1125"); - assert_eq!(data.shape[1], 1999, "Width should be 1999"); - assert_eq!(data.shape[2], 4, "RGBA channels should be 4"); + assert_eq!(descriptor.tensor_info.shape.len(), 3); + assert_eq!( + descriptor.tensor_info.shape[0], 1125, + "Height should be 1125" + ); + assert_eq!( + descriptor.tensor_info.shape[1], 1999, + "Width should be 1999" + ); + assert_eq!( + descriptor.tensor_info.shape[2], 4, + "RGBA channels should be 4" + ); - mock.assert_async().await; + assert!( + descriptor.source_storage.is_some(), + "Source storage should be present" + ); + assert!( + descriptor.source_storage.unwrap().is_registered(), + "Source storage should be registered with NIXL" + ); } #[test] diff --git a/lib/llm/src/preprocessor/media/rdma.rs b/lib/llm/src/preprocessor/media/rdma.rs new file mode 100644 index 0000000000..7b39e45111 --- /dev/null +++ b/lib/llm/src/preprocessor/media/rdma.rs @@ -0,0 +1,119 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use anyhow::Result; +use base64::{Engine as _, engine::general_purpose}; +use ndarray::{ArrayBase, Dimension, OwnedRepr}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +use dynamo_memory::{ + StorageError, SystemStorage, + nixl::{self, NixlAgent, NixlDescriptor, RegisteredView}, +}; + +use super::decoders::DecodedMediaMetadata; + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +pub enum DataType { + UINT8, +} + +// Common tensor metadata shared between decoded and RDMA descriptors +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct MediaTensorInfo { + pub(crate) shape: Vec, + pub(crate) dtype: DataType, + pub(crate) metadata: Option, +} + +// Decoded media data (image RGB, video frames pixels, ...) +#[derive(Debug)] +pub struct DecodedMediaData { + pub(crate) data: SystemStorage, + pub(crate) tensor_info: MediaTensorInfo, +} + +// Decoded media data NIXL descriptor (sent to the next step in the pipeline / NATS) +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct RdmaMediaDataDescriptor { + // b64 agent metadata + pub(crate) nixl_metadata: String, + // tensor descriptor + pub(crate) nixl_descriptor: NixlDescriptor, + + #[serde(flatten)] + pub(crate) tensor_info: MediaTensorInfo, + + // reference to the actual data, kept alive while the rdma descriptor is alive + #[serde(skip, default)] + #[allow(dead_code)] + pub(crate) source_storage: Option>>, +} + +impl DecodedMediaData { + pub fn into_rdma_descriptor(self, nixl_agent: &NixlAgent) -> Result { + // Register storage with NIXL + let source_storage = self.data; + let registered = nixl::register_with_nixl(source_storage, nixl_agent, None) + .map_err(|_| anyhow::anyhow!("Failed to register storage with NIXL"))?; + + let nixl_descriptor = registered.descriptor(); + let nixl_metadata = get_nixl_metadata(nixl_agent, registered.storage())?; + + Ok(RdmaMediaDataDescriptor { + nixl_metadata, + nixl_descriptor, + tensor_info: 
self.tensor_info, + // Keep registered storage alive + source_storage: Some(Arc::new(registered)), + }) + } +} + +// convert Array{N} to DecodedMediaData +// TODO: Array1 for audio +impl TryFrom, D>> for DecodedMediaData { + type Error = StorageError; + + fn try_from(array: ArrayBase, D>) -> Result { + let shape = array.shape().to_vec(); + let (data_vec, _) = array.into_raw_vec_and_offset(); + + // Allocate new system storage and copy data + // TODO: use arena allocator and avoid copies + let mut storage = SystemStorage::new(data_vec.len())?; + unsafe { + std::ptr::copy_nonoverlapping(data_vec.as_ptr(), storage.as_mut_ptr(), data_vec.len()); + } + + Ok(Self { + data: storage, + tensor_info: MediaTensorInfo { + shape, + dtype: DataType::UINT8, + metadata: None, + }, + }) + } +} + +// Get NIXL metadata for a descriptor +// Avoids cross-request leak possibility and reduces metadata size +// TODO: pre-allocate a fixed NIXL-registered RAM pool so metadata can be cached on the target? +pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result { + // WAR: Until https://github.com/ai-dynamo/nixl/pull/970 is merged, can't use get_local_partial_md + let nixl_md = agent.raw_agent().get_local_md()?; + // let mut reg_desc_list = RegDescList::new(MemType::Dram)?; + // reg_desc_list.add_storage_desc(storage)?; + // let nixl_partial_md = agent.raw_agent().get_local_partial_md(&reg_desc_list, None)?; + + let b64_encoded = general_purpose::STANDARD.encode(&nixl_md); + Ok(format!("b64:{}", b64_encoded)) +} + +pub fn get_nixl_agent() -> Result { + let name = format!("media-loader-{}", uuid::Uuid::new_v4()); + let nixl_agent = NixlAgent::with_backends(&name, &["UCX"])?; + Ok(nixl_agent) +} diff --git a/lib/llm/src/protocols/common/preprocessor.rs b/lib/llm/src/protocols/common/preprocessor.rs index 71260da5d2..7d5abf4c50 100644 --- a/lib/llm/src/protocols/common/preprocessor.rs +++ b/lib/llm/src/protocols/common/preprocessor.rs @@ -6,12 +6,13 @@ use 
serde::{Deserialize, Serialize}; use super::{OutputOptions, SamplingOptions, StopConditions}; use crate::kv_router::RouterConfigOverride; +use crate::preprocessor::media::RdmaMediaDataDescriptor; use crate::protocols::TokenIdType; #[derive(Serialize, Deserialize, Debug, Clone)] pub enum MultimodalData { Url(url::Url), - // TODO: Decoded(DecodedMediaData), + Decoded(RdmaMediaDataDescriptor), } // multimodal map containing {mm_part_type: [data...]} @@ -31,6 +32,7 @@ pub struct PreprocessedRequest { #[builder(default)] #[serde(default, skip_serializing_if = "Option::is_none")] pub multi_modal_data: Option, + /// StopConditions are conditions that the inference engine will use to stop generation. pub stop_conditions: StopConditions, diff --git a/lib/memory/Cargo.toml b/lib/memory/Cargo.toml index c435b2a278..7b7c513d42 100644 --- a/lib/memory/Cargo.toml +++ b/lib/memory/Cargo.toml @@ -26,7 +26,7 @@ dynamo-config = { workspace = true } anyhow = { workspace = true } cudarc = { workspace = true } -nixl-sys = { version = "0.7" } +nixl-sys = { git = "https://github.com/ai-dynamo/nixl", rev = "ae3f8af" } serde = { workspace = true} thiserror = { workspace = true } tracing = { workspace = true } diff --git a/lib/memory/src/nixl.rs b/lib/memory/src/nixl.rs index d81a338abc..776f176e53 100644 --- a/lib/memory/src/nixl.rs +++ b/lib/memory/src/nixl.rs @@ -14,6 +14,7 @@ pub use agent::NixlAgent; pub use config::NixlBackendConfig; pub use nixl_sys::{MemType, OptArgs, RegistrationHandle}; +pub use serde::{Deserialize, Serialize}; /// Trait for storage types that can be registered with NIXL. pub trait NixlCompatible { @@ -24,7 +25,7 @@ pub trait NixlCompatible { } /// NIXL descriptor containing registration information. 
-#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct NixlDescriptor { pub addr: u64, pub size: usize, diff --git a/lib/memory/src/nixl/agent.rs b/lib/memory/src/nixl/agent.rs index a1cf4343fb..f5361674ff 100644 --- a/lib/memory/src/nixl/agent.rs +++ b/lib/memory/src/nixl/agent.rs @@ -32,6 +32,10 @@ pub struct NixlAgent { impl NixlAgent { /// Create a NIXL agent without any backends. pub fn new(name: &str) -> Result { + if nixl_sys::is_stub() { + return Err(anyhow::anyhow!("NIXL is stubbed, cannot create agent")); + } + let agent = Agent::new(name)?; Ok(Self {