Skip to content

Commit c66bf8f

Browse files
authored
Merge pull request #7 from ScrapeGraphAI/fix/nested-arrays-issue-6
Fix: Preserve nested arrays and objects in array encoding (Issue #6)
2 parents c0b29ca + 38eacc0 commit c66bf8f

File tree

2 files changed

+205
-8
lines changed

2 files changed

+205
-8
lines changed

tests/test_nested_arrays.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
"""Tests for nested arrays and objects within tabular arrays (Issue #6)."""
2+
import pytest
3+
from toon import encode, decode
4+
5+
6+
def test_array_of_objects_with_nested_arrays():
    """Round-trip an array of objects whose element holds nested arrays.

    Regression test for Issue #6: the single ``categorization`` entry mixes
    primitive fields with a list of strings (``hierarchy``) and a list of
    objects (``offset``). After ``encode`` followed by ``decode`` the data
    must compare equal to the input, with both nested arrays intact.
    """
    hierarchy = [
        "Prodotti",
        "Organizzazione altro e Sito Internet",
        "Aspetti generali",
        "Aspetti generali",
    ]
    offsets = [
        {"start": 511, "end": 520},
        {"start": 524, "end": 527},
        {"start": 528, "end": 543},
    ]
    original = {
        "categorization": [
            {
                "id": "01.04.04.01.",
                "label": "Aspetti generali",
                "hierarchy": hierarchy,
                "score": 900,
                "winner": True,
                "namespace": "$namespace",
                "frequency": 0,
                "offset": offsets,
            },
        ],
    }

    # Encode, echoing the intermediate text so a failing run shows it.
    encoded = encode(original)
    print("Encoded TOON:")
    print(encoded)

    # Decode and echo the reconstructed structure as well.
    decoded = decode(encoded)
    print("\nDecoded result:")
    print(decoded)

    # Whole-structure equality first, then targeted checks on nested fields.
    assert decoded == original, "Decoded data should match original"
    entry = decoded["categorization"][0]
    assert "hierarchy" in entry, "hierarchy field should be preserved"
    assert "offset" in entry, "offset field should be preserved"
    assert len(entry["hierarchy"]) == 4, "hierarchy array should have 4 items"
    assert len(entry["offset"]) == 3, "offset array should have 3 items"
48+
49+
50+
def test_array_of_objects_with_nested_objects():
    """Round-trip an array of objects that each embed a nested object.

    Every user record has primitive ``id``/``name`` fields plus an
    ``address`` mapping; the decoded value must equal the input and keep
    the nested ``address`` data.
    """
    alice = {
        "id": 1,
        "name": "Alice",
        "address": {"street": "123 Main St", "city": "NYC"},
    }
    bob = {
        "id": 2,
        "name": "Bob",
        "address": {"street": "456 Oak Ave", "city": "LA"},
    }
    original = {"users": [alice, bob]}

    # Encode, echoing the intermediate text so a failing run shows it.
    encoded = encode(original)
    print("Encoded TOON:")
    print(encoded)

    decoded = decode(encoded)

    # Whole-structure equality first, then spot-check the nested object.
    assert decoded == original
    first_user = decoded["users"][0]
    assert "address" in first_user
    assert first_user["address"]["city"] == "NYC"
85+
86+
87+
def test_array_of_objects_mixed_primitive_and_nested():
    """Round-trip objects that combine primitive fields with a nested list.

    Each item has primitive ``id``, ``name`` and ``price`` values and a
    ``tags`` list of differing length; all of it must survive
    ``encode``/``decode`` unchanged.
    """
    item_a = {"id": 1, "name": "Item A", "tags": ["tag1", "tag2"], "price": 10.5}
    item_b = {"id": 2, "name": "Item B", "tags": ["tag3"], "price": 20.0}
    original = {"items": [item_a, item_b]}

    decoded = decode(encode(original))

    # Whole-structure equality first, then per-item checks on the tag lists.
    assert decoded == original
    first_item = decoded["items"][0]
    assert "tags" in first_item
    assert len(first_item["tags"]) == 2
    assert len(decoded["items"][1]["tags"]) == 1
117+
118+
119+
def test_roundtrip_complex_nested_structure():
    """Round-trip a nested structure twice and require a stable encoding.

    The records mix primitives, a nested object that itself contains a
    list, and a list of numbers. Both passes must reproduce the input, and
    encoding the decoded value must give the same text as the first pass.
    """
    record_a1 = {
        "id": "A1",
        "value": 100,
        "metadata": {
            "created": "2024-01-01",
            "tags": ["important", "urgent"],
        },
        "scores": [95, 87, 92],
    }
    record_a2 = {
        "id": "A2",
        "value": 200,
        "metadata": {
            "created": "2024-01-02",
            "tags": ["normal"],
        },
        "scores": [88, 90],
    }
    original = {"data": [record_a1, record_a2]}

    # First pass: original -> text -> value.
    first_text = encode(original)
    first_value = decode(first_text)
    assert first_value == original

    # Second pass starts from the decoded value of the first pass.
    second_text = encode(first_value)
    second_value = decode(second_text)
    assert second_value == original
    assert first_text == second_text
154+
155+
156+
def test_array_of_objects_some_with_nested_some_without():
    """Round-trip an array where only some objects carry a nested field.

    The first record has an ``extra`` mapping and the second does not.
    After decoding, ``extra`` must be present on the first record only.
    """
    with_extra = {
        "id": 1,
        "name": "Record A",
        "extra": {"note": "Has nested"},
    }
    without_extra = {
        "id": 2,
        "name": "Record B",
        # No 'extra' field
    }
    original = {"records": [with_extra, without_extra]}

    decoded = decode(encode(original))

    # The nested field stays on the first record and is not added to the second.
    records = decoded["records"]
    assert "extra" in records[0]
    assert "extra" not in records[1]
    assert records[0]["extra"]["note"] == "Has nested"
183+

toon/utils.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,33 +158,47 @@ def is_array_of_objects(value: Any) -> bool:
158158

159159
def is_uniform_array_of_objects(value: list) -> Optional[list]:
    """
    Return the shared field names of a uniform, primitive-only object array.

    This function determines if an array of objects can use the compact
    tabular format. Tabular format is only used when ALL fields in ALL
    objects are primitive types. If any object contains a non-primitive
    field (array or nested object), or the objects do not all share exactly
    the same keys, the function returns None so the encoder uses list array
    format instead and no data is dropped.

    Args:
        value: Array to check

    Returns:
        List of field names (in the first object's key order) if the array
        is uniform and all values are primitive, None otherwise
    """
    # An empty array, or one holding anything other than dicts, never qualifies.
    if not value:
        return None
    if not all(isinstance(item, dict) for item in value):
        return None

    # The first object defines the candidate columns; every value in it must
    # be primitive, and it must have at least one field.
    first_obj = value[0]
    if not all(is_primitive(val) for val in first_obj.values()):
        return None
    fields = list(first_obj)
    if not fields:
        return None

    # Every remaining object needs exactly the same key set (order is not
    # significant) and only primitive values.
    expected_keys = set(fields)
    for obj in value[1:]:
        if set(obj.keys()) != expected_keys:
            return None
        if not all(is_primitive(val) for val in obj.values()):
            return None

    return fields
190204

0 commit comments

Comments
 (0)