diff --git a/src/event/format/json.rs b/src/event/format/json.rs index 16ff77e36..c0be9320d 100644 --- a/src/event/format/json.rs +++ b/src/event/format/json.rs @@ -44,7 +44,7 @@ impl EventFormat for Event { fn to_data( self, schema: &HashMap>, - static_schema_flag: Option<&String>, + static_schema_flag: bool, time_partition: Option<&String>, schema_version: SchemaVersion, ) -> Result<(Self::Data, Vec>, bool), anyhow::Error> { @@ -94,7 +94,7 @@ impl EventFormat for Event { } }; - if static_schema_flag.is_none() + if !static_schema_flag && value_arr .iter() .any(|value| fields_mismatch(&schema, value, schema_version)) diff --git a/src/event/format/mod.rs b/src/event/format/mod.rs index 040a3eba7..094dbeb1b 100644 --- a/src/event/format/mod.rs +++ b/src/event/format/mod.rs @@ -77,7 +77,7 @@ pub trait EventFormat: Sized { fn to_data( self, schema: &HashMap>, - static_schema_flag: Option<&String>, + static_schema_flag: bool, time_partition: Option<&String>, schema_version: SchemaVersion, ) -> Result<(Self::Data, EventSchema, bool), AnyError>; @@ -87,7 +87,7 @@ pub trait EventFormat: Sized { fn into_recordbatch( self, storage_schema: &HashMap>, - static_schema_flag: Option<&String>, + static_schema_flag: bool, time_partition: Option<&String>, schema_version: SchemaVersion, ) -> Result<(RecordBatch, bool), AnyError> { @@ -130,9 +130,9 @@ pub trait EventFormat: Sized { fn is_schema_matching( new_schema: Arc, storage_schema: &HashMap>, - static_schema_flag: Option<&String>, + static_schema_flag: bool, ) -> bool { - if static_schema_flag.is_none() { + if !static_schema_flag { return true; } for field in new_schema.fields() { diff --git a/src/handlers/http/ingest.rs b/src/handlers/http/ingest.rs index f7955d735..6f3cb3f07 100644 --- a/src/handlers/http/ingest.rs +++ b/src/handlers/http/ingest.rs @@ -90,7 +90,7 @@ pub async fn ingest_internal_stream(stream_name: String, body: Bytes) -> Result< .clone(); let event = format::json::Event { data: body_val }; // For internal streams, use old schema - event.into_recordbatch(&schema, None, None, SchemaVersion::V0)? + event.into_recordbatch(&schema, false, None, SchemaVersion::V0)? }; event::Event { rb, @@ -285,7 +285,7 @@ pub async fn create_stream_if_not_exists( "", None, "", - "", + false, Arc::new(Schema::empty()), stream_type, ) @@ -405,7 +405,7 @@ mod tests { }); let (rb, _) = - into_event_batch(&json, HashMap::default(), None, None, SchemaVersion::V0).unwrap(); + into_event_batch(&json, HashMap::default(), false, None, SchemaVersion::V0).unwrap(); assert_eq!(rb.num_rows(), 1); assert_eq!(rb.num_columns(), 4); @@ -432,7 +432,7 @@ mod tests { }); let (rb, _) = - into_event_batch(&json, HashMap::default(), None, None, SchemaVersion::V0).unwrap(); + into_event_batch(&json, HashMap::default(), false, None, SchemaVersion::V0).unwrap(); assert_eq!(rb.num_rows(), 1); assert_eq!(rb.num_columns(), 3); @@ -462,7 +462,7 @@ mod tests { .into_iter(), ); - let (rb, _) = into_event_batch(&json, schema, None, None, SchemaVersion::V0).unwrap(); + let (rb, _) = into_event_batch(&json, schema, false, None, SchemaVersion::V0).unwrap(); assert_eq!(rb.num_rows(), 1); assert_eq!(rb.num_columns(), 3); @@ -492,7 +492,7 @@ mod tests { .into_iter(), ); - assert!(into_event_batch(&json, schema, None, None, SchemaVersion::V0,).is_err()); + assert!(into_event_batch(&json, schema, false, None, SchemaVersion::V0,).is_err()); } #[test] @@ -508,7 +508,7 @@ mod tests { .into_iter(), ); - let (rb, _) = into_event_batch(&json, schema, None, None, SchemaVersion::V0).unwrap(); + let (rb, _) = into_event_batch(&json, schema, false, None, SchemaVersion::V0).unwrap(); assert_eq!(rb.num_rows(), 1); assert_eq!(rb.num_columns(), 1); @@ -517,6 +517,7 @@ mod tests { #[test] fn non_object_arr_is_err() { let json = json!([1]); + assert!(convert_array_to_object( json, None, @@ -547,7 +548,7 @@ mod tests { ]); let (rb, _) = - into_event_batch(&json, HashMap::default(), None, None, SchemaVersion::V0).unwrap(); + into_event_batch(&json, HashMap::default(), false, None, SchemaVersion::V0).unwrap(); assert_eq!(rb.num_rows(), 3); assert_eq!(rb.num_columns(), 4); @@ -594,7 +595,7 @@ mod tests { ]); let (rb, _) = - into_event_batch(&json, HashMap::default(), None, None, SchemaVersion::V0).unwrap(); + into_event_batch(&json, HashMap::default(), false, None, SchemaVersion::V0).unwrap(); assert_eq!(rb.num_rows(), 3); assert_eq!(rb.num_columns(), 4); @@ -641,7 +642,7 @@ mod tests { .into_iter(), ); - let (rb, _) = into_event_batch(&json, schema, None, None, SchemaVersion::V0).unwrap(); + let (rb, _) = into_event_batch(&json, schema, false, None, SchemaVersion::V0).unwrap(); assert_eq!(rb.num_rows(), 3); assert_eq!(rb.num_columns(), 4); @@ -688,7 +689,7 @@ mod tests { .into_iter(), ); - assert!(into_event_batch(&json, schema, None, None, SchemaVersion::V0,).is_err()); + assert!(into_event_batch(&json, schema, false, None, SchemaVersion::V0,).is_err()); } #[test] @@ -729,7 +730,7 @@ mod tests { let (rb, _) = into_event_batch( &flattened_json, HashMap::default(), - None, + false, None, SchemaVersion::V0, ) @@ -817,7 +818,7 @@ mod tests { let (rb, _) = into_event_batch( &flattened_json, HashMap::default(), - None, + false, None, SchemaVersion::V1, ) diff --git a/src/handlers/http/logstream.rs b/src/handlers/http/logstream.rs index 6457d1ae2..cafe56190 100644 --- a/src/handlers/http/logstream.rs +++ b/src/handlers/http/logstream.rs @@ -489,7 +489,7 @@ pub async fn create_stream( time_partition: &str, time_partition_limit: Option, custom_partition: &str, - static_schema_flag: &str, + static_schema_flag: bool, schema: Arc, stream_type: &str, ) -> Result<(), CreateStreamError> { @@ -529,7 +529,7 @@ pub async fn create_stream( time_partition.to_string(), time_partition_limit, custom_partition.to_string(), - static_schema_flag.to_string(), + static_schema_flag, static_schema, stream_type, SchemaVersion::V1, // New stream @@ -582,7 +582,7 @@ pub async fn get_stream_info(req: HttpRequest) -> Result>, - static_schema_flag: Option<&String>, + static_schema_flag: bool, time_partition: Option<&String>, schema_version: SchemaVersion, ) -> Result<(arrow_array::RecordBatch, bool), PostError> { diff --git a/src/handlers/http/modal/utils/logstream_utils.rs b/src/handlers/http/modal/utils/logstream_utils.rs index d3fdd46eb..3543198d0 100644 --- a/src/handlers/http/modal/utils/logstream_utils.rs +++ b/src/handlers/http/modal/utils/logstream_utils.rs @@ -50,7 +50,7 @@ pub async fn create_update_stream( stream_type, ) = fetch_headers_from_put_stream_request(req); - if metadata::STREAM_INFO.stream_exists(stream_name) && update_stream_flag != "true" { + if metadata::STREAM_INFO.stream_exists(stream_name) && !update_stream_flag { return Err(StreamError::Custom { msg: format!( "Logstream {stream_name} already exists, please create a new log stream with unique name" @@ -71,12 +71,12 @@ pub async fn create_update_stream( }); } - if update_stream_flag == "true" { + if update_stream_flag { return update_stream( req, stream_name, &time_partition, - &static_schema_flag, + static_schema_flag, &time_partition_limit, &custom_partition, ) @@ -102,7 +102,7 @@ pub async fn create_update_stream( stream_name, &time_partition, &custom_partition, - &static_schema_flag, + static_schema_flag, )?; create_stream( @@ -110,7 +110,7 @@ pub async fn create_update_stream( &time_partition, time_partition_in_days, &custom_partition, - &static_schema_flag, + static_schema_flag, schema, &stream_type, ) @@ -123,7 +123,7 @@ async fn update_stream( req: &HttpRequest, stream_name: &str, time_partition: &str, - static_schema_flag: &str, + static_schema_flag: bool, time_partition_limit: &str, custom_partition: &str, ) -> Result { @@ -136,7 +136,7 @@ async fn update_stream( status: StatusCode::BAD_REQUEST, }); } - if !static_schema_flag.is_empty() { + if static_schema_flag { return Err(StreamError::Custom { msg: "Altering the schema of an existing stream is restricted.".to_string(), status: StatusCode::BAD_REQUEST, @@ -167,12 +167,12 @@ async fn validate_and_update_custom_partition( pub fn fetch_headers_from_put_stream_request( req: &HttpRequest, -) -> (String, String, String, String, String, String) { +) -> (String, String, String, bool, bool, String) { let mut time_partition = String::default(); let mut time_partition_limit = String::default(); let mut custom_partition = String::default(); - let mut static_schema_flag = String::default(); - let mut update_stream = String::default(); + let mut static_schema_flag = false; + let mut update_stream_flag = false; let mut stream_type = StreamType::UserDefined.to_string(); req.headers().iter().for_each(|(key, value)| { if key == TIME_PARTITION_KEY { @@ -184,11 +184,11 @@ pub fn fetch_headers_from_put_stream_request( if key == CUSTOM_PARTITION_KEY { custom_partition = value.to_str().unwrap().to_string(); } - if key == STATIC_SCHEMA_FLAG { - static_schema_flag = value.to_str().unwrap().to_string(); + if key == STATIC_SCHEMA_FLAG && value.to_str().unwrap() == "true" { + static_schema_flag = true; } - if key == UPDATE_STREAM_KEY { - update_stream = value.to_str().unwrap().to_string(); + if key == UPDATE_STREAM_KEY && value.to_str().unwrap() == "true" { + update_stream_flag = true; } if key == STREAM_TYPE_KEY { stream_type = value.to_str().unwrap().to_string(); @@ -200,7 +200,7 @@ pub fn fetch_headers_from_put_stream_request( time_partition_limit, custom_partition, static_schema_flag, - update_stream, + update_stream_flag, stream_type, ) } @@ -258,9 +258,9 @@ pub fn validate_static_schema( stream_name: &str, time_partition: &str, custom_partition: &str, - static_schema_flag: &str, + static_schema_flag: bool, ) -> Result, CreateStreamError> { - if static_schema_flag == "true" { + if static_schema_flag { if body.is_empty() { return Err(CreateStreamError::Custom { msg: format!( @@ -317,7 +317,7 @@ pub async fn update_custom_partition_in_stream( ) -> Result<(), CreateStreamError> { let static_schema_flag = STREAM_INFO.get_static_schema_flag(&stream_name).unwrap(); let time_partition = STREAM_INFO.get_time_partition(&stream_name).unwrap(); - if static_schema_flag.is_some() { + if static_schema_flag { let schema = STREAM_INFO.schema(&stream_name).unwrap(); if !custom_partition.is_empty() { @@ -383,7 +383,7 @@ pub async fn create_stream( time_partition: &str, time_partition_limit: Option, custom_partition: &str, - static_schema_flag: &str, + static_schema_flag: bool, schema: Arc, stream_type: &str, ) -> Result<(), CreateStreamError> { @@ -423,7 +423,7 @@ pub async fn create_stream( time_partition.to_string(), time_partition_limit, custom_partition.to_string(), - static_schema_flag.to_string(), + static_schema_flag, static_schema, stream_type, SchemaVersion::V1, // New stream @@ -473,7 +473,7 @@ pub async fn create_stream_and_schema_from_storage(stream_name: &str) -> Result< .time_partition_limit .and_then(|limit| limit.parse().ok()); let custom_partition = stream_metadata.custom_partition.as_deref().unwrap_or(""); - let static_schema_flag = stream_metadata.static_schema_flag.as_deref().unwrap_or(""); + let static_schema_flag = stream_metadata.static_schema_flag; let stream_type = stream_metadata.stream_type.as_deref().unwrap_or(""); let schema_version = stream_metadata.schema_version; @@ -483,7 +483,7 @@ pub async fn create_stream_and_schema_from_storage(stream_name: &str) -> Result< time_partition.to_string(), time_partition_limit, custom_partition.to_string(), - static_schema_flag.to_string(), + static_schema_flag, static_schema, stream_type, schema_version, diff --git a/src/kafka.rs b/src/kafka.rs index ebfbeb5ba..45c7d5220 100644 --- a/src/kafka.rs +++ b/src/kafka.rs @@ -194,7 +194,7 @@ async fn ingest_message(msg: BorrowedMessage<'_>) -> Result<(), KafkaError> { let (rb, is_first) = event .into_recordbatch( &schema, - static_schema_flag.as_ref(), + static_schema_flag, time_partition.as_ref(), schema_version, ) diff --git a/src/metadata.rs b/src/metadata.rs index 017ddab49..94954c4c0 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -72,7 +72,7 @@ pub struct LogStreamMetadata { pub time_partition: Option, pub time_partition_limit: Option, pub custom_partition: Option, - pub static_schema_flag: Option, + pub static_schema_flag: bool, pub hot_tier_enabled: Option, pub stream_type: Option, } @@ -144,14 +144,11 @@ impl StreamInfo { .map(|metadata| metadata.custom_partition.clone()) } - pub fn get_static_schema_flag( - &self, - stream_name: &str, - ) -> Result, MetadataError> { + pub fn get_static_schema_flag(&self, stream_name: &str) -> Result { let map = self.read().expect(LOCK_EXPECT); map.get(stream_name) .ok_or(MetadataError::StreamMetaNotFound(stream_name.to_string())) - .map(|metadata| metadata.static_schema_flag.clone()) + .map(|metadata| metadata.static_schema_flag) } pub fn get_retention(&self, stream_name: &str) -> Result, MetadataError> { @@ -270,7 +267,7 @@ impl StreamInfo { time_partition: String, time_partition_limit: Option, custom_partition: String, - static_schema_flag: String, + static_schema_flag: bool, static_schema: HashMap>, stream_type: &str, schema_version: SchemaVersion, @@ -293,11 +290,7 @@ impl StreamInfo { } else { Some(custom_partition) }, - static_schema_flag: if static_schema_flag != "true" { - None - } else { - Some(static_schema_flag) - }, + static_schema_flag, schema: if static_schema.is_empty() { HashMap::new() } else { diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 9c6ef9a5c..573800812 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -20,9 +20,11 @@ use crate::{ catalog::snapshot::Snapshot, metadata::{error::stream_info::MetadataError, SchemaVersion}, stats::FullStats, + utils::json::{deserialize_string_as_true, serialize_bool_as_true}, }; use chrono::Local; +use serde::{Deserialize, Serialize}; use std::fmt::Debug; @@ -76,7 +78,7 @@ const ACCESS_ALL: &str = "all"; pub const CURRENT_OBJECT_STORE_VERSION: &str = "v5"; pub const CURRENT_SCHEMA_VERSION: &str = "v5"; -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ObjectStoreFormat { /// Version of schema registry pub version: String, @@ -104,14 +106,19 @@ pub struct ObjectStoreFormat { pub time_partition_limit: Option, #[serde(skip_serializing_if = "Option::is_none")] pub custom_partition: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub static_schema_flag: Option, + #[serde( + default, // sets to false if not configured + deserialize_with = "deserialize_string_as_true", + serialize_with = "serialize_bool_as_true", + skip_serializing_if = "std::ops::Not::not" + )] + pub static_schema_flag: bool, #[serde(skip_serializing_if = "Option::is_none")] pub hot_tier_enabled: Option, pub stream_type: Option, } -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct StreamInfo { #[serde(rename = "created-at")] pub created_at: String, @@ -124,8 +131,13 @@ pub struct StreamInfo { pub time_partition_limit: Option, #[serde(skip_serializing_if = "Option::is_none")] pub custom_partition: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub static_schema_flag: Option, + #[serde( + default, // sets to false if not configured + deserialize_with = "deserialize_string_as_true", + serialize_with = "serialize_bool_as_true", + skip_serializing_if = "std::ops::Not::not" + )] + pub static_schema_flag: bool, pub stream_type: Option, } @@ -191,7 +203,7 @@ impl Default for ObjectStoreFormat { time_partition: None, time_partition_limit: None, custom_partition: None, - static_schema_flag: None, + static_schema_flag: false, hot_tier_enabled: None, } } diff --git a/src/storage/object_storage.rs b/src/storage/object_storage.rs index 272ab47bb..9c07f8068 100644 --- a/src/storage/object_storage.rs +++ b/src/storage/object_storage.rs @@ -151,7 +151,7 @@ pub trait ObjectStorage: Send + Sync + 'static { time_partition: &str, time_partition_limit: Option, custom_partition: &str, - static_schema_flag: &str, + static_schema_flag: bool, schema: Arc, stream_type: &str, ) -> Result { @@ -162,8 +162,7 @@ pub trait ObjectStorage: Send + Sync + 'static { time_partition: (!time_partition.is_empty()).then(|| time_partition.to_string()), time_partition_limit: time_partition_limit.map(|limit| limit.to_string()), custom_partition: (!custom_partition.is_empty()).then(|| custom_partition.to_string()), - static_schema_flag: (static_schema_flag == "true") - .then(|| static_schema_flag.to_string()), + static_schema_flag, schema_version: SchemaVersion::V1, // NOTE: Newly created streams are all V1 owner: Owner { id: CONFIG.parseable.username.clone(), @@ -560,7 +559,7 @@ pub trait ObjectStorage: Send + Sync + 'static { let static_schema_flag = STREAM_INFO .get_static_schema_flag(stream) .map_err(|err| ObjectStorageError::UnhandledError(Box::new(err)))?; - if static_schema_flag.is_none() { + if !static_schema_flag { commit_schema_to_storage(stream, schema).await?; } } diff --git a/src/utils/json/mod.rs b/src/utils/json/mod.rs index 4c794afba..0f5c05812 100644 --- a/src/utils/json/mod.rs +++ b/src/utils/json/mod.rs @@ -16,9 +16,11 @@ * */ +use std::fmt; use std::num::NonZeroU32; use flatten::{convert_to_array, generic_flattening, has_more_than_four_levels}; +use serde::de::Visitor; use serde_json; use serde_json::Value; @@ -107,11 +109,65 @@ pub fn convert_to_string(value: &Value) -> Value { } } +struct TrueFromStr; + +impl Visitor<'_> for TrueFromStr { + type Value = bool; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string containing \"true\"") + } + + fn visit_borrowed_str(self, v: &'_ str) -> Result + where + E: serde::de::Error, + { + self.visit_str(v) + } + + fn visit_str(self, s: &str) -> Result + where + E: serde::de::Error, + { + match s { + "true" => Ok(true), + other => Err(E::custom(format!( + r#"Expected value: "true", got: {}"#, + other + ))), + } + } +} + +/// Used to convert "true" to boolean true and everything else is failed. +/// This is necessary because the default deserializer for bool in serde is not +/// able to handle the value "true", which we have previously written to config. +pub fn deserialize_string_as_true<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + deserializer.deserialize_str(TrueFromStr) +} + +/// Used to convert boolean true to "true" and everything else is skipped. +pub fn serialize_bool_as_true(value: &bool, serializer: S) -> Result +where + S: serde::Serializer, +{ + if *value { + serializer.serialize_str("true") + } else { + // Skip serializing this field + serializer.serialize_none() + } +} + #[cfg(test)] mod tests { use crate::event::format::LogSource; - use super::flatten_json_body; + use super::*; + use serde::{Deserialize, Serialize}; use serde_json::json; #[test] @@ -151,4 +207,95 @@ mod tests { expected ); } + + #[derive(Serialize, Deserialize)] + struct TestBool { + #[serde( + default, + deserialize_with = "deserialize_string_as_true", + serialize_with = "serialize_bool_as_true", + skip_serializing_if = "std::ops::Not::not" + )] + value: bool, + other_field: String, + } + + #[test] + fn deserialize_true() { + let json = r#"{"value": "true", "other_field": "test"}"#; + let test_bool: TestBool = serde_json::from_str(json).unwrap(); + assert!(test_bool.value); + } + + #[test] + fn deserialize_none_as_false() { + let json = r#"{"other_field": "test"}"#; + let test_bool: TestBool = serde_json::from_str(json).unwrap(); + assert!(!test_bool.value); + } + + #[test] + fn fail_to_deserialize_invalid_value_including_false_or_raw_bool() { + let json = r#"{"value": "false", "other_field": "test"}"#; + assert!(serde_json::from_str::(json).is_err()); + + let json = r#"{"value": true, "other_field": "test"}"#; + assert!(serde_json::from_str::(json).is_err()); + + let json = r#"{"value": false, "other_field": "test"}"#; + assert!(serde_json::from_str::(json).is_err()); + + let json = r#"{"value": "invalid", "other_field": "test"}"#; + assert!(serde_json::from_str::(json).is_err()); + + let json = r#"{"value": 123}"#; + assert!(serde_json::from_str::(json).is_err()); + + let json = r#"{"value": null}"#; + assert!(serde_json::from_str::(json).is_err()); + } + + #[test] + fn serialize_true_value() { + let test_bool = TestBool { + value: true, + other_field: "test".to_string(), + }; + let json = serde_json::to_string(&test_bool).unwrap(); + assert_eq!(json, r#"{"value":"true","other_field":"test"}"#); + } + + #[test] + fn serialize_false_value_skips_field() { + let test_bool = TestBool { + value: false, + other_field: "test".to_string(), + }; + let json = serde_json::to_string(&test_bool).unwrap(); + assert_eq!(json, r#"{"other_field":"test"}"#); + } + + #[test] + fn roundtrip_true() { + let original = TestBool { + value: true, + other_field: "test".to_string(), + }; + let json = serde_json::to_string(&original).unwrap(); + let deserialized: TestBool = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.value, original.value); + assert_eq!(deserialized.other_field, original.other_field); + } + + #[test] + fn roundtrip_false() { + let original = TestBool { + value: false, + other_field: "test".to_string(), + }; + let json = serde_json::to_string(&original).unwrap(); + let deserialized: TestBool = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.value, original.value); + assert_eq!(deserialized.other_field, original.other_field); + } }