diff --git a/docs/docs/ops/targets.md b/docs/docs/ops/targets.md index ba511961f..bb783f95a 100644 --- a/docs/docs/ops/targets.md +++ b/docs/docs/ops/targets.md @@ -39,6 +39,13 @@ For all other vector types, we map them to `jsonb` columns. ::: +:::info U+0000 (NUL) characters in strings + +U+0000 (NUL) is a valid character in Unicode, but Postgres has a limitation that strings (including `text`-like types and strings in `jsonb`) cannot contain them. +CocoIndex automatically strips U+0000 (NUL) characters from strings before exporting to Postgres. For example, if you have a string `"Hello\0World"`, it will be exported as `"HelloWorld"`. + +::: + #### Spec The spec takes the following fields: diff --git a/src/ops/targets/postgres.rs b/src/ops/targets/postgres.rs index 692045a92..741e062b7 100644 --- a/src/ops/targets/postgres.rs +++ b/src/ops/targets/postgres.rs @@ -43,7 +43,7 @@ fn bind_value_field<'arg>( builder.push_bind(&**v); } BasicValue::Str(v) => { - builder.push_bind(&**v); + builder.push_bind(utils::str_sanitize::ZeroCodeStrippedEncode(v.as_ref())); } BasicValue::Bool(v) => { builder.push_bind(v); @@ -82,7 +82,9 @@ fn bind_value_field<'arg>( builder.push_bind(v); } BasicValue::Json(v) => { - builder.push_bind(sqlx::types::Json(&**v)); + builder.push_bind(sqlx::types::Json( + utils::str_sanitize::ZeroCodeStrippedSerialize(&**v), + )); } BasicValue::Vector(v) => match &field_schema.value_type.typ { ValueType::Basic(BasicValueType::Vector(vs)) if convertible_to_pgvector(vs) => { @@ -104,20 +106,24 @@ fn bind_value_field<'arg>( } }, BasicValue::UnionVariant { .. } => { - builder.push_bind(sqlx::types::Json(TypedValue { - t: &field_schema.value_type.typ, - v: value, - })); + builder.push_bind(sqlx::types::Json( + utils::str_sanitize::ZeroCodeStrippedSerialize(TypedValue { + t: &field_schema.value_type.typ, + v: value, + }), + )); } }, Value::Null => { builder.push("NULL"); } v => { - builder.push_bind(sqlx::types::Json(TypedValue { - t: &field_schema.value_type.typ, - v, - })); + builder.push_bind(sqlx::types::Json( + utils::str_sanitize::ZeroCodeStrippedSerialize(TypedValue { + t: &field_schema.value_type.typ, + v, + }), + )); } }; Ok(()) diff --git a/src/utils/mod.rs b/src/utils/mod.rs index a13e05b8c..41f8de9e1 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -3,4 +3,5 @@ pub mod db; pub mod fingerprint; pub mod immutable; pub mod retryable; +pub mod str_sanitize; pub mod yaml_ser; diff --git a/src/utils/str_sanitize.rs b/src/utils/str_sanitize.rs new file mode 100644 index 000000000..17b483e13 --- /dev/null +++ b/src/utils/str_sanitize.rs @@ -0,0 +1,597 @@ +use std::borrow::Cow; +use std::fmt::Display; + +use serde::Serialize; +use serde::ser::{ + SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, + SerializeTupleStruct, SerializeTupleVariant, +}; +use sqlx::Type; +use sqlx::encode::{Encode, IsNull}; +use sqlx::error::BoxDynError; +use sqlx::postgres::{PgArgumentBuffer, Postgres}; + +pub fn strip_zero_code<'a>(s: Cow<'a, str>) -> Cow<'a, str> { + if s.contains('\0') { + let mut sanitized = String::with_capacity(s.len()); + for ch in s.chars() { + if ch != '\0' { + sanitized.push(ch); + } + } + Cow::Owned(sanitized) + } else { + s + } +} + +/// A thin wrapper for sqlx parameter binding that strips NUL (\0) bytes +/// from the wrapped string before encoding. +/// +/// Usage: wrap a string reference when binding: +/// `query.bind(ZeroCodeStrippedEncode(my_str))` +#[derive(Copy, Clone, Debug)] +pub struct ZeroCodeStrippedEncode<'a>(pub &'a str); + +impl<'a> Type for ZeroCodeStrippedEncode<'a> { + fn type_info() -> ::TypeInfo { + <&'a str as Type>::type_info() + } + + fn compatible(ty: &::TypeInfo) -> bool { + <&'a str as Type>::compatible(ty) + } +} + +impl<'a> Encode<'a, Postgres> for ZeroCodeStrippedEncode<'a> { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + let sanitized = strip_zero_code(Cow::Borrowed(self.0)); + <&str as Encode<'a, Postgres>>::encode_by_ref(&sanitized.as_ref(), buf) + } + + fn size_hint(&self) -> usize { + self.0.len() + } +} + +/// A wrapper that sanitizes zero bytes from strings during serialization. +/// +/// It ensures: +/// - All string values have zero bytes removed +/// - Struct field names are sanitized before being written +/// - Map keys and any nested content are sanitized recursively +pub struct ZeroCodeStrippedSerialize(pub T); + +impl Serialize for ZeroCodeStrippedSerialize +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let sanitizing = SanitizingSerializer { inner: serializer }; + self.0.serialize(sanitizing) + } +} + +/// Internal serializer wrapper that strips zero bytes from strings and sanitizes +/// struct field names by routing struct serialization through maps with sanitized keys. +struct SanitizingSerializer { + inner: S, +} + +// Helper newtype to apply sanitizing serializer to any &T during nested serialization +struct SanitizeRef<'a, T: ?Sized>(&'a T); + +impl<'a, T> Serialize for SanitizeRef<'a, T> +where + T: ?Sized + Serialize, +{ + fn serialize( + &self, + serializer: S1, + ) -> Result<::Ok, ::Error> + where + S1: serde::Serializer, + { + let sanitizing = SanitizingSerializer { inner: serializer }; + self.0.serialize(sanitizing) + } +} + +// Seq wrapper to sanitize nested elements +struct SanitizingSerializeSeq { + inner: S::SerializeSeq, +} + +impl SerializeSeq for SanitizingSerializeSeq +where + S: serde::Serializer, +{ + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.inner.serialize_element(&SanitizeRef(value)) + } + + fn end(self) -> Result { + self.inner.end() + } +} + +// Tuple wrapper +struct SanitizingSerializeTuple { + inner: S::SerializeTuple, +} + +impl SerializeTuple for SanitizingSerializeTuple +where + S: serde::Serializer, +{ + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.inner.serialize_element(&SanitizeRef(value)) + } + + fn end(self) -> Result { + self.inner.end() + } +} + +// Tuple struct wrapper +struct SanitizingSerializeTupleStruct { + inner: S::SerializeTupleStruct, +} + +impl SerializeTupleStruct for SanitizingSerializeTupleStruct +where + S: serde::Serializer, +{ + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.inner.serialize_field(&SanitizeRef(value)) + } + + fn end(self) -> Result { + self.inner.end() + } +} + +// Tuple variant wrapper +struct SanitizingSerializeTupleVariant { + inner: S::SerializeTupleVariant, +} + +impl SerializeTupleVariant for SanitizingSerializeTupleVariant +where + S: serde::Serializer, +{ + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.inner.serialize_field(&SanitizeRef(value)) + } + + fn end(self) -> Result { + self.inner.end() + } +} + +// Map wrapper; ensures keys and values are sanitized +struct SanitizingSerializeMap { + inner: S::SerializeMap, +} + +impl SerializeMap for SanitizingSerializeMap +where + S: serde::Serializer, +{ + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.inner.serialize_key(&SanitizeRef(key)) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.inner.serialize_value(&SanitizeRef(value)) + } + + fn serialize_entry(&mut self, key: &K, value: &V) -> Result<(), Self::Error> + where + K: ?Sized + Serialize, + V: ?Sized + Serialize, + { + self.inner + .serialize_entry(&SanitizeRef(key), &SanitizeRef(value)) + } + + fn end(self) -> Result { + self.inner.end() + } +} + +// Struct wrapper: implement via inner map to allow dynamic, sanitized field names +struct SanitizingSerializeStruct { + inner: S::SerializeMap, +} + +impl SerializeStruct for SanitizingSerializeStruct +where + S: serde::Serializer, +{ + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.inner + .serialize_entry(&SanitizeRef(&key), &SanitizeRef(value)) + } + + fn end(self) -> Result { + self.inner.end() + } +} + +impl serde::Serializer for SanitizingSerializer +where + S: serde::Serializer, +{ + type Ok = S::Ok; + type Error = S::Error; + type SerializeSeq = SanitizingSerializeSeq; + type SerializeTuple = SanitizingSerializeTuple; + type SerializeTupleStruct = SanitizingSerializeTupleStruct; + type SerializeTupleVariant = SanitizingSerializeTupleVariant; + type SerializeMap = SanitizingSerializeMap; + type SerializeStruct = SanitizingSerializeStruct; + type SerializeStructVariant = SanitizingSerializeStructVariant; + + fn serialize_bool(self, v: bool) -> Result { + self.inner.serialize_bool(v) + } + + fn serialize_i8(self, v: i8) -> Result { + self.inner.serialize_i8(v) + } + + fn serialize_i16(self, v: i16) -> Result { + self.inner.serialize_i16(v) + } + + fn serialize_i32(self, v: i32) -> Result { + self.inner.serialize_i32(v) + } + + fn serialize_i64(self, v: i64) -> Result { + self.inner.serialize_i64(v) + } + + fn serialize_u8(self, v: u8) -> Result { + self.inner.serialize_u8(v) + } + + fn serialize_u16(self, v: u16) -> Result { + self.inner.serialize_u16(v) + } + + fn serialize_u32(self, v: u32) -> Result { + self.inner.serialize_u32(v) + } + + fn serialize_u64(self, v: u64) -> Result { + self.inner.serialize_u64(v) + } + + fn serialize_f32(self, v: f32) -> Result { + self.inner.serialize_f32(v) + } + + fn serialize_f64(self, v: f64) -> Result { + self.inner.serialize_f64(v) + } + + fn serialize_char(self, v: char) -> Result { + // A single char cannot contain a NUL; forward directly + self.inner.serialize_char(v) + } + + fn serialize_str(self, v: &str) -> Result { + let sanitized = strip_zero_code(Cow::Borrowed(v)); + self.inner.serialize_str(sanitized.as_ref()) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + self.inner.serialize_bytes(v) + } + + fn serialize_none(self) -> Result { + self.inner.serialize_none() + } + + fn serialize_some(self, value: &T) -> Result + where + T: ?Sized + Serialize, + { + self.inner.serialize_some(&SanitizeRef(value)) + } + + fn serialize_unit(self) -> Result { + self.inner.serialize_unit() + } + + fn serialize_unit_struct(self, name: &'static str) -> Result { + // Type names are not field names; forward + self.inner.serialize_unit_struct(name) + } + + fn serialize_unit_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + ) -> Result { + // Variant names are not field names; forward + self.inner + .serialize_unit_variant(name, variant_index, variant) + } + + fn serialize_newtype_struct( + self, + name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + self.inner + .serialize_newtype_struct(name, &SanitizeRef(value)) + } + + fn serialize_newtype_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + self.inner + .serialize_newtype_variant(name, variant_index, variant, &SanitizeRef(value)) + } + + fn serialize_seq(self, len: Option) -> Result { + Ok(SanitizingSerializeSeq { + inner: self.inner.serialize_seq(len)?, + }) + } + + fn serialize_tuple(self, len: usize) -> Result { + Ok(SanitizingSerializeTuple { + inner: self.inner.serialize_tuple(len)?, + }) + } + + fn serialize_tuple_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + Ok(SanitizingSerializeTupleStruct { + inner: self.inner.serialize_tuple_struct(name, len)?, + }) + } + + fn serialize_tuple_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + Ok(SanitizingSerializeTupleVariant { + inner: self + .inner + .serialize_tuple_variant(name, variant_index, variant, len)?, + }) + } + + fn serialize_map(self, len: Option) -> Result { + Ok(SanitizingSerializeMap { + inner: self.inner.serialize_map(len)?, + }) + } + + fn serialize_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + // Route through a map so we can provide dynamically sanitized field names + Ok(SanitizingSerializeStruct { + inner: self.inner.serialize_map(Some(len))?, + }) + } + + fn serialize_struct_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + Ok(SanitizingSerializeStructVariant { + inner: self + .inner + .serialize_struct_variant(name, variant_index, variant, len)?, + }) + } + + fn is_human_readable(&self) -> bool { + self.inner.is_human_readable() + } + + fn collect_str(self, value: &T) -> Result + where + T: ?Sized + Display, + { + let s = value.to_string(); + let sanitized = strip_zero_code(Cow::Owned(s)); + self.inner.serialize_str(sanitized.as_ref()) + } +} + +// Struct variant wrapper: sanitize field names and nested values +struct SanitizingSerializeStructVariant { + inner: S::SerializeStructVariant, +} + +impl SerializeStructVariant for SanitizingSerializeStructVariant +where + S: serde::Serializer, +{ + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + // Cannot allocate dynamic field names here due to &'static str bound. + // Sanitize only values. + self.inner.serialize_field(key, &SanitizeRef(value)) + } + + fn end(self) -> Result { + self.inner.end() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::Serialize; + use serde_json::{Value, json}; + use std::borrow::Cow; + use std::collections::BTreeMap; + + #[test] + fn strip_zero_code_no_change_borrowed() { + let input = "abc"; + let out = strip_zero_code(Cow::Borrowed(input)); + assert!(matches!(out, Cow::Borrowed(_))); + assert_eq!(out.as_ref(), "abc"); + } + + #[test] + fn strip_zero_code_removes_nuls_owned() { + let input = "a\0b\0c\0".to_string(); + let out = strip_zero_code(Cow::Owned(input)); + assert_eq!(out.as_ref(), "abc"); + } + + #[test] + fn wrapper_sanitizes_plain_string_value() { + let s = "he\0ll\0o"; + let v: Value = serde_json::to_value(ZeroCodeStrippedSerialize(s)).unwrap(); + assert_eq!(v, json!("hello")); + } + + #[test] + fn wrapper_sanitizes_map_keys_and_values() { + let mut m = BTreeMap::new(); + m.insert("a\0b".to_string(), "x\0y".to_string()); + m.insert("\0start".to_string(), "en\0d".to_string()); + let v: Value = serde_json::to_value(ZeroCodeStrippedSerialize(&m)).unwrap(); + let obj = v.as_object().unwrap(); + assert_eq!(obj.get("ab").unwrap(), &json!("xy")); + assert_eq!(obj.get("start").unwrap(), &json!("end")); + assert!(!obj.contains_key("a\0b")); + assert!(!obj.contains_key("\0start")); + } + + #[derive(Serialize)] + struct TestStruct { + #[serde(rename = "fi\0eld")] // Intentionally includes NUL + value: String, + #[serde(rename = "n\0ested")] // Intentionally includes NUL + nested: Inner, + } + + #[derive(Serialize)] + struct Inner { + #[serde(rename = "n\0ame")] // Intentionally includes NUL + name: String, + } + + #[test] + fn wrapper_sanitizes_struct_field_names_and_values() { + let s = TestStruct { + value: "hi\0!".to_string(), + nested: Inner { + name: "al\0ice".to_string(), + }, + }; + let v: Value = serde_json::to_value(ZeroCodeStrippedSerialize(&s)).unwrap(); + let obj = v.as_object().unwrap(); + assert!(obj.contains_key("field")); + assert!(obj.contains_key("nested")); + assert_eq!(obj.get("field").unwrap(), &json!("hi!")); + let nested = obj.get("nested").unwrap().as_object().unwrap(); + assert!(nested.contains_key("name")); + assert_eq!(nested.get("name").unwrap(), &json!("alice")); + assert!(!obj.contains_key("fi\0eld")); + } + + #[derive(Serialize)] + enum TestEnum { + Var { + #[serde(rename = "ke\0y")] // Intentionally includes NUL + field: String, + }, + } + + #[test] + fn wrapper_sanitizes_struct_variant_values_only() { + let e = TestEnum::Var { + field: "b\0ar".to_string(), + }; + let v: Value = serde_json::to_value(ZeroCodeStrippedSerialize(&e)).unwrap(); + // {"Var":{"key":"bar"}} + let root = v.as_object().unwrap(); + let var = root.get("Var").unwrap().as_object().unwrap(); + // Field name remains unchanged due to &'static str constraint of SerializeStructVariant + assert!(var.contains_key("ke\0y")); + assert_eq!(var.get("ke\0y").unwrap(), &json!("bar")); + } +}