Skip to content

Commit 28ef0e1

Browse files
authored
Croissant data enrichment (#595)
1 parent db2f8d9 commit 28ef0e1

File tree

2 files changed

+112
-39
lines changed

2 files changed

+112
-39
lines changed

aperturedb/MLCroissant.py

Lines changed: 111 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
1+
import dataclasses
2+
import hashlib
13
import io
24
import json
3-
from typing import Any, List, Tuple
4-
5-
import PIL
6-
import PIL.Image
5+
import logging
6+
import PIL.GifImagePlugin
77
import mlcroissant as mlc
8+
import PIL.Image
89
import pandas as pd
910

11+
from typing import Any, List, Tuple
1012

1113
from aperturedb.Subscriptable import Subscriptable
1214
from aperturedb.Query import QueryBuilder
13-
from aperturedb.CommonLibrary import execute_query
15+
from aperturedb.DataModels import IdentityDataModel
16+
from aperturedb.Query import generate_add_query
1417

1518

16-
import dataclasses
17-
import hashlib
19+
logger = logging.getLogger(__name__)
1820

19-
from aperturedb.DataModels import IdentityDataModel
20-
from aperturedb.Query import generate_add_query
2121

2222
MAX_REF_VALUE = 99999
23+
# This is useful to identify the class of the record in ApertureDB.
24+
CLASS_PROPERTY_NAME = "adb_class_name"
2325

2426

2527
class RecordSetModel(IdentityDataModel):
@@ -29,6 +31,7 @@ class RecordSetModel(IdentityDataModel):
2931

3032

3133
class DatasetModel(IdentityDataModel):
34+
url: str = ""
3235
name: str = "Croissant Dataset automatically ingested into ApertureDB"
3336
description: str = f"A dataset loaded from a croissant json-ld"
3437
version: str = "1.0.0"
@@ -54,10 +57,20 @@ def deserialize_record(record):
5457
if record == pd.NaT:
5558
deserialized = "Not Available Time"
5659
if isinstance(deserialized, str):
57-
try:
58-
deserialized = json.loads(deserialized)
59-
except:
60-
pass
60+
if deserialized.startswith("[") or deserialized.startswith("{"):
61+
# If it looks like a list or dict, try to parse it as JSON
62+
try:
63+
deserialized = json.loads(deserialized)
64+
except json.JSONDecodeError:
65+
logger.info(f"Failed to parse JSON: {deserialized}")
66+
67+
try:
68+
deserialized = json.loads(deserialized.replace("'", "\""))
69+
except Exception as e:
70+
logger.info(
71+
f"Failed to parse JSON: {deserialized} with error {e}")
72+
pass
73+
6174
if isinstance(deserialized, list):
6275
deserialized = [deserialize_record(item) for item in deserialized]
6376
if isinstance(deserialized, dict):
@@ -67,9 +80,12 @@ def deserialize_record(record):
6780
return deserialized
6881

6982

70-
def persist_metadata(dataset: mlc.Dataset) -> Tuple[List[dict], List[bytes]]:
71-
83+
def persist_metadata(dataset: mlc.Dataset, url: str) -> Tuple[List[dict], List[bytes]]:
84+
"""
85+
Persist the metadata of a croissant dataset into ApertureDB.
86+
"""
7287
ds = DatasetModel(
88+
url=url,
7389
name=dataset.metadata.name,
7490
description=dataset.metadata.description,
7591
version=dataset.metadata.version or "1.0.0",
@@ -84,38 +100,83 @@ def persist_metadata(dataset: mlc.Dataset) -> Tuple[List[dict], List[bytes]]:
84100
return q, b
85101

86102

103+
def try_parse(value: str) -> Any:
104+
"""Attempts to parse a string value into a more appropriate type."""
105+
parsed = value.strip()
106+
107+
if parsed.startswith("http"):
108+
# Download the content from the URL
109+
from aperturedb.Sources import Sources
110+
sources = Sources(n_download_retries=3)
111+
result, buffer = sources.load_from_http_url(
112+
parsed, validator=lambda x: True)
113+
if result:
114+
parsed = PIL.Image.open(io.BytesIO(buffer))
115+
116+
return parsed
117+
118+
87119
def dict_to_query(row_dict, name: str, flatten_json: bool) -> Any:
88120
literals = {}
89121
subitems = {}
90-
blobs = {}
91-
o_literalse = {}
122+
known_image_blobs = {}
123+
unknown_blobs = {}
124+
o_literals = {}
92125

93-
# If name is not specified, or begins with _, this enures that it
126+
name = name.split("/")[-1] # Use the last part of the name
127+
# If name is not specified, or begins with _, this ensures that it
94128
# complies with the ApertureDB naming conventions
95-
name = f"E_{name or 'Record'}"
129+
if not name or name.startswith("_"):
130+
safe_name = f"E_{name or 'Record'}" # Uncomment if you want
131+
logger.warning(
132+
f"Entity Name '{name}' is not valid. Using {safe_name}.")
133+
name = safe_name
96134

97135
for k, v in row_dict.items():
98-
k = f"F_{k}"
136+
k = k.split("/")[-1] # Use the last part of the key
137+
if not k or k.startswith("_"):
138+
safe_key = f"F_{k or 'Field'}"
139+
logger.warning(
140+
f"Property name '{k}' is not valid. Using {safe_key}.")
141+
k = safe_key
99142
item = v
143+
# Pre processed items from croissant.
100144
if isinstance(item, PIL.Image.Image):
101145
buffer = io.BytesIO()
102146
item.save(buffer, format=item.format)
103-
blobs[k] = buffer.getvalue()
147+
known_image_blobs[k] = buffer.getvalue()
104148
continue
105149

106150
record = deserialize_record(item)
151+
if isinstance(record, str):
152+
record = try_parse(record)
153+
154+
# Post processed items from SDK.
155+
if isinstance(record, PIL.GifImagePlugin.GifImageFile):
156+
buffer = io.BytesIO()
157+
record.save(buffer, format=record.format)
158+
unknown_blobs[k] = buffer.getvalue()
159+
continue
160+
161+
if isinstance(record, PIL.Image.Image):
162+
buffer = io.BytesIO()
163+
record.save(buffer, format=record.format)
164+
known_image_blobs[k] = buffer.getvalue()
165+
continue
166+
107167
if flatten_json and isinstance(record, list):
108168
subitems[k] = record
109169
else:
110170
literals[k] = record
111-
o_literalse[k] = item
171+
# Original value from croissant. This is useful for debugging.
172+
o_literals[k] = item
112173

113174
if flatten_json:
114175
str_rep = "".join([f"{str(k)}{str(v)}" for k, v in literals.items()])
115176
literals["adb_uuid"] = hashlib.sha256(
116177
str_rep.encode('utf-8')).hexdigest()
117178

118-
literals["adb_class_name"] = name
179+
literals[CLASS_PROPERTY_NAME] = name
119180
q = QueryBuilder.add_command(name, {
120181
"properties": literals,
121182
"connect": {
@@ -130,34 +191,48 @@ def dict_to_query(row_dict, name: str, flatten_json: bool) -> Any:
130191
}
131192

132193
dependents = []
133-
if len(subitems) > 0 or len(blobs) > 0:
194+
if len(subitems) > 0 or len(known_image_blobs) > 0 or len(unknown_blobs) > 0:
195+
# We need to create a reference to this record
134196
q[list(q.keys())[-1]]["_ref"] = 1
135197

136198
for key in subitems:
137199
for item in subitems[key]:
138-
subitem_query = dict_to_query(item, f"{name}.{key}", flatten_json)
200+
subitem_query, blobs = dict_to_query(
201+
item, f"{name}.{key}", flatten_json)
139202
subitem_query[0][list(subitem_query[0].keys())[-1]]["connect"] = {
140203
"ref": 1,
141204
"class": key,
142-
"direction": "out",
205+
"direction": "in",
143206
}
144207
dependents.extend(subitem_query)
145208

146209
from aperturedb.Query import ObjectType
147-
image_blobs = []
148-
for blob in blobs:
210+
blobs = []
211+
for blob in known_image_blobs:
149212
image_query = QueryBuilder.add_command(ObjectType.IMAGE, {
150-
"properties": literals,
213+
"properties": {CLASS_PROPERTY_NAME: literals[CLASS_PROPERTY_NAME] + "." + "image"},
151214
"connect": {
152215
"ref": 1,
153216
"class": blob,
154-
"direction": "out"
217+
"direction": "in"
155218
}
156219
})
157-
image_blobs.append(blobs[blob])
220+
blobs.append(known_image_blobs[blob])
158221
dependents.append(image_query)
159222

160-
return [q] + dependents, image_blobs
223+
for blob in unknown_blobs:
224+
blob_query = QueryBuilder.add_command(ObjectType.BLOB, {
225+
"properties": {CLASS_PROPERTY_NAME: literals[CLASS_PROPERTY_NAME] + "." + "blob"},
226+
"connect": {
227+
"ref": 1,
228+
"class": blob,
229+
"direction": "in"
230+
}
231+
})
232+
blobs.append(unknown_blobs[blob])
233+
dependents.append(blob_query)
234+
235+
return [q] + dependents, blobs
161236

162237

163238
class MLCroissantRecordSet(Subscriptable):
@@ -178,16 +253,14 @@ def __init__(
178253
if count == sample_count:
179254
break
180255

181-
self.df = pd.json_normalize(samples)
256+
self.samples = samples
182257
self.sample_count = len(samples)
183258
self.name = name
184259
self.flatten_json = flatten_json
185260
self.indexed_entities = set()
186261

187262
def getitem(self, subscript):
188-
row = self.df.iloc[subscript]
189-
# Convert the row to a dictionary
190-
row_dict = row.to_dict()
263+
row_dict = self.samples[subscript]
191264

192265
find_recordset_query = QueryBuilder.find_command(
193266
"RecordSetModel", {
@@ -201,7 +274,7 @@ def getitem(self, subscript):
201274
indexes_to_create = []
202275
for command in q:
203276
cmd = list(command.keys())[-1]
204-
if cmd == "AddImage":
277+
if cmd in ["AddImage", "AddBlob", "AddVideo"]:
205278
continue
206279
indexable_entity = command[list(command.keys())[-1]]["class"]
207280
if indexable_entity not in self.indexed_entities:
@@ -216,4 +289,4 @@ def getitem(self, subscript):
216289
return indexes_to_create + [find_recordset_query] + q, blobs
217290

218291
def __len__(self):
219-
return len(self.df)
292+
return len(self.samples)

aperturedb/cli/ingest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def from_croissant(
266266
from aperturedb.MLCroissant import MLCroissantRecordSet, persist_metadata
267267

268268
croissant_dataset = mlc.Dataset(url)
269-
metadata = persist_metadata(croissant_dataset)
269+
metadata = persist_metadata(croissant_dataset, url)
270270
_process_data(
271271
[metadata],
272272
sample_count=1,

0 commit comments

Comments
 (0)