Skip to content

Commit db4c5ac

Browse files
fixup: capture array extras and mismatches on deserialize_with_unknowns_derive
1 parent 6ff0cb4 commit db4c5ac

File tree

2 files changed

+168
-24
lines changed
  • lightning-liquidity/src/lsps0
  • lightning-macros/src

2 files changed

+168
-24
lines changed

lightning-liquidity/src/lsps0/ser.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,3 +1086,98 @@ pub(crate) mod u32_fee_rate {
10861086
Ok(FeeRate::from_sat_per_kwu(fee_rate_sat_kwu as u64))
10871087
}
10881088
}
1089+
1090+
#[cfg(test)]
1091+
mod tests {
1092+
use super::DeserializeWithUnknowns;
1093+
use serde::ser::SerializeSeq;
1094+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
1095+
1096+
#[derive(Debug, Clone, PartialEq, Eq)]
1097+
struct TruncVec<T>(Vec<T>);
1098+
1099+
impl<T> TruncVec<T> {
1100+
fn new(v: Vec<T>) -> Self {
1101+
Self(v)
1102+
}
1103+
}
1104+
1105+
impl<T: Serialize> Serialize for TruncVec<T> {
1106+
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
1107+
let n = if self.0.is_empty() { 0 } else { 1 };
1108+
let mut seq = serializer.serialize_seq(Some(n))?;
1109+
if n == 1 {
1110+
seq.serialize_element(&self.0[0])?;
1111+
}
1112+
seq.end()
1113+
}
1114+
}
1115+
1116+
impl<'de, T: Deserialize<'de>> Deserialize<'de> for TruncVec<T> {
1117+
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
1118+
let v = Vec::<T>::deserialize(deserializer)?;
1119+
Ok(Self(v))
1120+
}
1121+
}
1122+
1123+
#[derive(Debug, Clone, PartialEq, Eq)]
1124+
struct AsStringAny;
1125+
1126+
impl Serialize for AsStringAny {
1127+
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
1128+
serializer.serialize_str("constant")
1129+
}
1130+
}
1131+
1132+
impl<'de> Deserialize<'de> for AsStringAny {
1133+
fn deserialize<D: Deserializer<'de>>(_deserializer: D) -> Result<Self, D::Error> {
1134+
Ok(AsStringAny)
1135+
}
1136+
}
1137+
1138+
#[derive(
1139+
crate::prelude::DeserializeWithUnknowns, Deserialize, Serialize, Debug, PartialEq, Eq,
1140+
)]
1141+
struct TopArray(TruncVec<u8>);
1142+
1143+
#[derive(
1144+
crate::prelude::DeserializeWithUnknowns, Deserialize, Serialize, Debug, PartialEq, Eq,
1145+
)]
1146+
struct NestedArray {
1147+
list: TruncVec<u8>,
1148+
}
1149+
1150+
#[derive(
1151+
crate::prelude::DeserializeWithUnknowns, Deserialize, Serialize, Debug, PartialEq, Eq,
1152+
)]
1153+
struct MismatchNested {
1154+
key: AsStringAny,
1155+
}
1156+
1157+
#[test]
1158+
fn array_extras_top_level_are_reported() {
1159+
let input = serde_json::json!([1, 2, 3]);
1160+
let (val, unknown) =
1161+
<TopArray as DeserializeWithUnknowns>::deserialize_with_unknowns(input).unwrap();
1162+
assert_eq!(val, TopArray(TruncVec::new(vec![1, 2, 3])));
1163+
assert_eq!(unknown, vec!("[1]", "[2]"));
1164+
}
1165+
1166+
#[test]
1167+
fn array_extras_nested_are_reported() {
1168+
let input = serde_json::json!({"list": [10, 20, 30]});
1169+
let (val, unknown) =
1170+
<NestedArray as DeserializeWithUnknowns>::deserialize_with_unknowns(input).unwrap();
1171+
assert_eq!(val, NestedArray { list: TruncVec::new(vec![10, 20, 30]) });
1172+
assert_eq!(unknown, vec!("list[1]", "list[2]"));
1173+
}
1174+
1175+
#[test]
1176+
fn object_vs_nonobject_mismatch_reports_children() {
1177+
let input = serde_json::json!({"key": {"unknown": 1, "also_unknown": "x"}});
1178+
let (_val, mut unknown) =
1179+
<MismatchNested as DeserializeWithUnknowns>::deserialize_with_unknowns(input).unwrap();
1180+
unknown.sort();
1181+
assert_eq!(unknown, vec!("key.also_unknown", "key.unknown"));
1182+
}
1183+
}

lightning-macros/src/lib.rs

Lines changed: 73 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,45 @@ pub fn deserialize_with_unknowns_derive(input: TokenStream) -> TokenStream {
445445
use ::alloc::string::ToString;
446446
if base_path.is_empty() { ["[", &index.to_string(), "]"].concat() } else { [base_path, "[", &index.to_string(), "]"].concat() }
447447
}
448+
// Records unknowns for mismatched shapes by enumerating immediate children.
449+
fn record_mismatch_unknowns(
450+
input_json: &::serde_json::Value,
451+
base_path: &str,
452+
out_paths: &mut ::alloc::vec::Vec<::alloc::string::String>,
453+
) {
454+
use ::serde_json::Value::*;
455+
match input_json {
456+
Object(map) => {
457+
for (k, v) in map.iter() {
458+
if !v.is_null() {
459+
out_paths.push(join_path(base_path, k));
460+
}
461+
}
462+
},
463+
Array(arr) => {
464+
for (idx, v) in arr.iter().enumerate() {
465+
if !v.is_null() {
466+
out_paths.push(index_path(base_path, idx));
467+
}
468+
}
469+
},
470+
_ => { /* No-op */ }
471+
}
472+
}
473+
474+
// Marks extra non-null input array elements beyond recognized length as unknown.
475+
fn record_extra_array_items(
476+
input_array: &::alloc::vec::Vec<::serde_json::Value>,
477+
recognized_len: usize,
478+
base_path: &str,
479+
out_paths: &mut ::alloc::vec::Vec<::alloc::string::String>,
480+
) {
481+
for idx in recognized_len..input_array.len() {
482+
if !input_array[idx].is_null() {
483+
out_paths.push(index_path(base_path, idx));
484+
}
485+
}
486+
}
448487
// Recursively walk input vs recognized JSON and collect unknown field paths.
449488
fn collect_unknown_paths(
450489
input_json: &::serde_json::Value,
@@ -463,29 +502,30 @@ pub fn deserialize_with_unknowns_derive(input: TokenStream) -> TokenStream {
463502
out_paths.push(join_path(base_path, key));
464503
}
465504
},
466-
// Key known. Recurse based on the value shape.
467-
Some(recognized_value) => match (input_value, recognized_value) {
468-
(Object(_), Object(_)) => collect_unknown_paths(
469-
input_value,
470-
recognized_value,
471-
&join_path(base_path, key),
472-
out_paths,
473-
),
474-
(Array(input_array), Array(recognized_array)) => {
475-
for (idx, (input_elem, recognized_elem)) in input_array
476-
.iter()
477-
.zip(recognized_array.iter())
478-
.enumerate()
479-
{
480-
collect_unknown_paths(
481-
input_elem,
482-
recognized_elem,
483-
&index_path(&join_path(base_path, key), idx),
484-
out_paths,
485-
);
486-
}
487-
},
488-
_ => { /* No-op */ }
505+
// Key known. Recurse or record mismatch.
506+
Some(recognized_value) => {
507+
let child_base = join_path(base_path, key);
508+
match (input_value, recognized_value) {
509+
(Object(_), Object(_)) => collect_unknown_paths(
510+
input_value, recognized_value, &child_base, out_paths,
511+
),
512+
(Array(input_array), Array(recognized_array)) => {
513+
for (idx, (input_elem, recognized_elem)) in input_array
514+
.iter()
515+
.zip(recognized_array.iter())
516+
.enumerate()
517+
{
518+
collect_unknown_paths(
519+
input_elem,
520+
recognized_elem,
521+
&index_path(&child_base, idx),
522+
out_paths,
523+
);
524+
}
525+
record_extra_array_items(input_array, recognized_array.len(), &child_base, out_paths);
526+
},
527+
_ => record_mismatch_unknowns(input_value, &child_base, out_paths),
528+
}
489529
}
490530
}
491531
}
@@ -503,8 +543,17 @@ pub fn deserialize_with_unknowns_derive(input: TokenStream) -> TokenStream {
503543
out_paths,
504544
);
505545
}
546+
// If the input array is longer than the recognized array,
547+
// mark the extra elements as unknown.
548+
if input_array.len() > recognized_array.len() {
549+
for idx in recognized_array.len()..input_array.len() {
550+
if !input_array[idx].is_null() {
551+
out_paths.push(index_path(base_path, idx));
552+
}
553+
}
554+
}
506555
},
507-
_ => { /* No-op */ }
556+
_ => record_mismatch_unknowns(input_json, base_path, out_paths),
508557
}
509558
}
510559

0 commit comments

Comments
 (0)