Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
600 changes: 585 additions & 15 deletions Cargo.lock

Large diffs are not rendered by default.

209 changes: 13 additions & 196 deletions ros-z-codegen/src/generator/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ pub fn generate_message_impl_with_context(
ctx: &GenerationContext,
) -> Result<TokenStream> {
let name = format_ident!("{}", msg.parsed.name);
let has_zbuf = msg.parsed.fields.iter().any(is_zbuf_field);

let struct_def = generate_struct_with_context(
&msg.parsed.package,
Expand All @@ -68,20 +67,14 @@ pub fn generate_message_impl_with_context(
let type_info =
generate_message_type_info(&msg.parsed.package, &msg.parsed.name, &msg.type_hash);

// Generate custom serde for ZBuf-containing messages
let serde_impl = if has_zbuf {
generate_zbuf_serde_with_context(&name, &msg.parsed.fields, &msg.parsed.package, ctx)?
} else {
quote! {}
};
// No longer need custom serde - ros_z::ZBuf implements Serialize/Deserialize

// Generate size estimation implementation
let size_estimation_impl = generate_size_estimation_impl(&name, &msg.parsed.fields, &msg.parsed.package, ctx)?;

Ok(quote! {
#struct_def
#type_info
#serde_impl
#size_estimation_impl
})
}
Expand Down Expand Up @@ -122,9 +115,6 @@ fn generate_struct_with_context(
.iter()
.any(|f| matches!(&f.field_type.array, ArrayType::Fixed(n) if *n > 32));

// Check if we have ZBuf fields (need custom serde)
let has_zbuf = fields.iter().any(is_zbuf_field);

// Generate constants as associated constants
let const_defs: Vec<TokenStream> = constants
.iter()
Expand All @@ -141,22 +131,7 @@ fn generate_struct_with_context(
// Python bridge module path for derive macros
let py_module_path = format!("ros_z_msgs_py.types.{}", package);

if has_zbuf {
// Messages with ZBuf fields need custom serde (no derive)
// But still need Default
Ok(quote! {
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "python_registry", derive(::ros_z_derive::FromPyMessage, ::ros_z_derive::IntoPyMessage))]
#[cfg_attr(feature = "python_registry", ros_msg(module = #py_module_path))]
pub struct #name_ident {
#(#field_defs),*
}

impl #name_ident {
#(#const_defs)*
}
})
} else if has_large_array {
if has_large_array {
// Large array messages need smart-default for arrays >32 elements
Ok(quote! {
#[derive(Debug, Clone, ::smart_default::SmartDefault, ::serde::Serialize, ::serde::Deserialize)]
Expand Down Expand Up @@ -385,9 +360,9 @@ fn generate_field_type_tokens_with_context(
quote! { [#base; #n_lit] }
}
ArrayType::Unbounded => {
// Use ZBuf for uint8[]/byte[] (zero-copy optimization)
// Use ros_z::ZBuf wrapper for uint8[]/byte[] (zero-copy with optimized serde)
if matches!(field_type.base_type.as_str(), "uint8" | "byte") {
quote! { ::zenoh_buffers::ZBuf }
quote! { ::ros_z::ZBuf }
} else {
quote! { ::std::vec::Vec<#base> }
}
Expand Down Expand Up @@ -531,8 +506,12 @@ fn generate_field_size_expr(
// Handle different field types
if is_zbuf_field(field) {
// ZBuf: 4 bytes length prefix + data length
// Note: ros_z::ZBuf derefs to zenoh_buffers::ZBuf, so .len() works via Deref
Ok(quote! {
size += 4 + ::zenoh_buffers::buffer::Buffer::len(&self.#field_name);
size += 4 + {
use ::zenoh_buffers::buffer::Buffer;
self.#field_name.len()
};
})
} else if field.field_type.base_type == "string" {
match &field.field_type.array {
Expand Down Expand Up @@ -644,167 +623,6 @@ fn get_primitive_size(base_type: &str) -> Result<usize> {
})
}

/// Generate custom serde implementation for ZBuf-containing messages
#[allow(dead_code)]
fn generate_zbuf_serde(
name: &proc_macro2::Ident,
fields: &[Field],
source_package: &str,
) -> Result<TokenStream> {
generate_zbuf_serde_with_context(name, fields, source_package, &GenerationContext::default())
}

/// Generate custom serde implementation for ZBuf-containing messages (with external type support)
fn generate_zbuf_serde_with_context(
name: &proc_macro2::Ident,
fields: &[Field],
source_package: &str,
ctx: &GenerationContext,
) -> Result<TokenStream> {
let serialize_fields: Vec<TokenStream> = fields
.iter()
.map(|f| {
let field_name = escape_field_name(&f.name);
let field_name_str = &f.name;
if is_zbuf_field(f) {
// OPTIMIZED: Use a wrapper type that calls serialize_bytes() directly
// instead of serialize_field() which treats &[u8] as a sequence.
// This is critical for performance with large byte arrays!
quote! {
{
use ::zenoh_buffers::buffer::SplitBuffer;
let bytes = self.#field_name.contiguous();
// Wrapper that calls serialize_bytes() directly for efficiency
struct BytesSerializer<'a>(&'a [u8]);
impl<'a> ::serde::Serialize for BytesSerializer<'a> {
fn serialize<S: ::serde::Serializer>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> {
serializer.serialize_bytes(self.0)
}
}
state.serialize_field(#field_name_str, &BytesSerializer(bytes.as_ref()))?;
}
}
} else {
quote! {
state.serialize_field(#field_name_str, &self.#field_name)?;
}
}
})
.collect();

// Generate sequential field deserialization for CDR (positional binary format)
let seq_deserialize_fields: Vec<TokenStream> = fields
.iter()
.enumerate()
.map(|(i, f)| {
let field_name = escape_field_name(&f.name);
let field_type = generate_field_type_tokens_with_context(&f.field_type, source_package, ctx).unwrap();
if is_zbuf_field(f) {
// OPTIMIZED: Use a wrapper type that calls deserialize_bytes() directly
// instead of the default Vec<u8> deserialization which uses deserialize_seq().
// This is critical for performance with large byte arrays!
quote! {
let #field_name: #field_type = {
// Wrapper that uses deserialize_bytes for efficient bulk reading
struct BytesDeserializer;
impl<'de> ::serde::de::DeserializeSeed<'de> for BytesDeserializer {
type Value = ::std::vec::Vec<u8>;
fn deserialize<D: ::serde::Deserializer<'de>>(self, deserializer: D) -> ::std::result::Result<Self::Value, D::Error> {
struct BytesVisitor;
impl<'de> ::serde::de::Visitor<'de> for BytesVisitor {
type Value = ::std::vec::Vec<u8>;
fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
formatter.write_str("byte array")
}
fn visit_bytes<E: ::serde::de::Error>(self, v: &[u8]) -> ::std::result::Result<Self::Value, E> {
Ok(v.to_vec())
}
fn visit_borrowed_bytes<E: ::serde::de::Error>(self, v: &'de [u8]) -> ::std::result::Result<Self::Value, E> {
Ok(v.to_vec())
}
fn visit_byte_buf<E: ::serde::de::Error>(self, v: ::std::vec::Vec<u8>) -> ::std::result::Result<Self::Value, E> {
Ok(v)
}
// Fallback for formats that don't support bytes natively
fn visit_seq<A: ::serde::de::SeqAccess<'de>>(self, mut seq: A) -> ::std::result::Result<Self::Value, A::Error> {
let len = seq.size_hint().unwrap_or(0);
let mut bytes = ::std::vec::Vec::with_capacity(len);
while let Some(b) = seq.next_element()? {
bytes.push(b);
}
Ok(bytes)
}
}
deserializer.deserialize_bytes(BytesVisitor)
}
}
let bytes: ::std::vec::Vec<u8> = seq.next_element_seed(BytesDeserializer)?
.ok_or_else(|| ::serde::de::Error::invalid_length(#i, &self))?;
::zenoh_buffers::ZBuf::from(bytes)
};
}
} else {
quote! {
let #field_name: #field_type = seq.next_element()?
.ok_or_else(|| ::serde::de::Error::invalid_length(#i, &self))?;
}
}
})
.collect();

let field_names: Vec<_> = fields.iter().map(|f| escape_field_name(&f.name)).collect();
let field_name_strs: Vec<_> = fields.iter().map(|f| &f.name).collect();
let num_fields = fields.len();

Ok(quote! {
impl ::serde::Serialize for #name {
fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
where
S: ::serde::Serializer,
{
use ::serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct(stringify!(#name), #num_fields)?;
#(#serialize_fields)*
state.end()
}
}

impl<'de> ::serde::Deserialize<'de> for #name {
fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
where
D: ::serde::Deserializer<'de>,
{
use ::serde::de::{SeqAccess, Visitor};
use ::std::fmt;

struct FieldVisitor;

impl<'de> Visitor<'de> for FieldVisitor {
type Value = #name;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str(concat!("struct ", stringify!(#name)))
}

fn visit_seq<V>(self, mut seq: V) -> ::std::result::Result<Self::Value, V::Error>
where
V: SeqAccess<'de>,
{
#(#seq_deserialize_fields)*

Ok(#name {
#(#field_names),*
})
}
}

const FIELDS: &[&str] = &[#(#field_name_strs),*];
deserializer.deserialize_struct(stringify!(#name), FIELDS, FieldVisitor)
}
}
})
}

/// Generate service type implementation
pub fn generate_service_impl(srv: &ResolvedService) -> Result<TokenStream> {
let name = format_ident!("{}", srv.parsed.name);
Expand Down Expand Up @@ -1064,11 +882,10 @@ mod tests {
let tokens = result.unwrap();
let code = tokens.to_string();

// Should contain ZBuf field (optimized for zero-copy)
assert!(code.contains("zenoh_buffers :: ZBuf"));
// Should have custom Serialize implementation (not derived)
assert!(code.contains("impl :: serde :: Serialize"));
assert!(code.contains("impl < 'de > :: serde :: Deserialize"));
// Should contain ros_z::ZBuf field (wrapper with optimized serde)
assert!(code.contains("ros_z :: ZBuf"));
// Should use derived Serialize/Deserialize (ZBuf wrapper implements these traits)
assert!(code.contains("derive"));
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions ros-z-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ fn generate_field_extraction(
// Fallback for lists (slow path)
py_attr.extract()?
};
::zenoh_buffers::ZBuf::from(bytes)
::ros_z::ZBuf::from(bytes)
}
});
}
Expand Down Expand Up @@ -288,7 +288,7 @@ fn generate_field_extraction(
// Fallback for lists (slow path)
py_attr.extract()?
};
::zenoh_buffers::ZBuf::from(bytes)
::ros_z::ZBuf::from(bytes)
}
}),
}
Expand Down
3 changes: 2 additions & 1 deletion ros-z-msgs/tests/shm_size_estimation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
use std::sync::Arc;

use ros_z::{
ZBuf,
msg::{ZMessage, ZSerializer},
shm::ShmProviderBuilder,
};
use zenoh_buffers::{ZBuf, buffer::Buffer};
use zenoh_buffers::buffer::Buffer;

#[test]
fn test_pointcloud2_shm_serialization_with_accurate_estimate() {
Expand Down
4 changes: 2 additions & 2 deletions ros-z-msgs/tests/size_estimation_performance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

use std::time::Instant;

use ros_z::msg::ZMessage;
use ros_z::{ZBuf, msg::ZMessage};
use ros_z_msgs::{builtin_interfaces::Time, sensor_msgs::*, std_msgs::Header};
use zenoh_buffers::{ZBuf, buffer::Buffer};
use zenoh_buffers::buffer::Buffer;

#[test]
fn test_pointcloud2_serialization_performance() {
Expand Down
8 changes: 3 additions & 5 deletions ros-z-msgs/tests/zbuf_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
//! serialized, and that the zero-copy optimization works correctly.

use byteorder::LittleEndian;
use ros_z::ZBuf;
use ros_z_cdr::to_vec;
use ros_z_msgs::{sensor_msgs::CompressedImage, std_msgs::Header};
use zenoh_buffers::{
ZBuf,
buffer::{Buffer, SplitBuffer},
};
use zenoh_buffers::buffer::{Buffer, SplitBuffer};

#[test]
fn test_zbuf_field_serialization() {
Expand Down Expand Up @@ -40,7 +38,7 @@ fn test_zbuf_empty() {
let img = CompressedImage {
header: Header::default(),
format: "png".to_string(),
data: ZBuf::empty(),
data: ZBuf::default(),
};

assert_eq!(img.data.len(), 0);
Expand Down
4 changes: 2 additions & 2 deletions ros-z/examples/shm_pointcloud2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
use std::{sync::Arc, time::Instant};

use ros_z::{
Builder,
Builder, ZBuf,
context::ZContextBuilder,
shm::{ShmConfig, ShmProviderBuilder},
};
Expand All @@ -29,7 +29,7 @@ use zenoh::{
Wait,
shm::{BlockOn, GarbageCollect, ShmProvider},
};
use zenoh_buffers::{ZBuf, buffer::Buffer};
use zenoh_buffers::buffer::Buffer;

fn main() -> zenoh::Result<()> {
println!("=== PointCloud2 with SHM Example ===\n");
Expand Down
7 changes: 2 additions & 5 deletions ros-z/examples/z_pingpong.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,9 @@ use std::{

use clap::Parser;
use csv::Writer;
use ros_z::{Builder, Result, context::ZContextBuilder};
use ros_z::{Builder, Result, ZBuf, context::ZContextBuilder};
use ros_z_msgs::std_msgs::ByteMultiArray;
use zenoh_buffers::{
ZBuf,
buffer::{Buffer, SplitBuffer},
};
use zenoh_buffers::buffer::{Buffer, SplitBuffer};

#[derive(Debug, Parser)]
struct Args {
Expand Down
2 changes: 2 additions & 0 deletions ros-z/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ pub mod ros_msg;
pub mod service;
pub mod shm;
pub mod topic_name;
pub mod zbuf;

#[macro_use]
pub mod utils;

pub use attachment::GidArray;
pub use entity::{TypeHash, TypeInfo};
pub use ros_msg::{ActionTypeInfo, MessageTypeInfo, ServiceTypeInfo, WithTypeInfo};
pub use zbuf::ZBuf;
pub use zenoh::Result;

pub trait Builder {
Expand Down
Loading
Loading