1+ import dataclasses
2+ import hashlib
13import io
24import json
3- from typing import Any , List , Tuple
4-
5- import PIL
6- import PIL .Image
5+ import logging
6+ import PIL .GifImagePlugin
77import mlcroissant as mlc
8+ import PIL .Image
89import pandas as pd
910
11+ from typing import Any , List , Tuple
1012
1113from aperturedb .Subscriptable import Subscriptable
1214from aperturedb .Query import QueryBuilder
13- from aperturedb .CommonLibrary import execute_query
15+ from aperturedb .DataModels import IdentityDataModel
16+ from aperturedb .Query import generate_add_query
1417
1518
16- import dataclasses
17- import hashlib
19+ logger = logging .getLogger (__name__ )
1820
19- from aperturedb .DataModels import IdentityDataModel
20- from aperturedb .Query import generate_add_query
2121
# Upper bound for "_ref" values accepted by ApertureDB queries.
MAX_REF_VALUE = 99999

# Property stamped on every ingested record so its originating model class
# can be identified later inside ApertureDB.
CLASS_PROPERTY_NAME = "adb_class_name"
2325
2426
2527class RecordSetModel (IdentityDataModel ):
@@ -29,6 +31,7 @@ class RecordSetModel(IdentityDataModel):
2931
3032
3133class DatasetModel (IdentityDataModel ):
34+ url : str = ""
3235 name : str = "Croissant Dataset automatically ingested into ApertureDB"
3336 description : str = f"A dataset loaded from a croissant json-ld"
3437 version : str = "1.0.0"
@@ -54,10 +57,20 @@ def deserialize_record(record):
5457 if record == pd .NaT :
5558 deserialized = "Not Available Time"
5659 if isinstance (deserialized , str ):
57- try :
58- deserialized = json .loads (deserialized )
59- except :
60- pass
60+ if deserialized .startswith ("[" ) or deserialized .startswith ("{" ):
61+ # If it looks like a list or dict, try to parse it as JSON
62+ try :
63+ deserialized = json .loads (deserialized )
64+ except json .JSONDecodeError :
65+ logger .info (f"Failed to parse JSON: { deserialized } " )
66+
67+ try :
68+ deserialized = json .loads (deserialized .replace ("'" , "\" " ))
69+ except Exception as e :
70+ logger .info (
71+ f"Failed to parse JSON: { deserialized } with error { e } " )
72+ pass
73+
6174 if isinstance (deserialized , list ):
6275 deserialized = [deserialize_record (item ) for item in deserialized ]
6376 if isinstance (deserialized , dict ):
@@ -67,9 +80,12 @@ def deserialize_record(record):
6780 return deserialized
6881
6982
70- def persist_metadata (dataset : mlc .Dataset ) -> Tuple [List [dict ], List [bytes ]]:
71-
83+ def persist_metadata (dataset : mlc .Dataset , url : str ) -> Tuple [List [dict ], List [bytes ]]:
84+ """
85+ Persist the metadata of a croissant dataset into ApertureDB.
86+ """
7287 ds = DatasetModel (
88+ url = url ,
7389 name = dataset .metadata .name ,
7490 description = dataset .metadata .description ,
7591 version = dataset .metadata .version or "1.0.0" ,
@@ -84,38 +100,83 @@ def persist_metadata(dataset: mlc.Dataset) -> Tuple[List[dict], List[bytes]]:
84100 return q , b
85101
86102
def try_parse(value: str) -> Any:
    """Strip *value* and, when it looks like an http(s) URL, fetch its content.

    Non-URL strings are returned stripped and otherwise unchanged.  A value
    starting with "http" is downloaded (3 retries); on success the payload is
    decoded as a PIL image, on failure the stripped string is returned as-is.
    """
    text = value.strip()
    if not text.startswith("http"):
        return text

    # Deferred import keeps module import cheap when no URLs are present.
    from aperturedb.Sources import Sources
    ok, payload = Sources(n_download_retries=3).load_from_http_url(
        text, validator=lambda _: True)
    return PIL.Image.open(io.BytesIO(payload)) if ok else text
117+
118+
87119def dict_to_query (row_dict , name : str , flatten_json : bool ) -> Any :
88120 literals = {}
89121 subitems = {}
90- blobs = {}
91- o_literalse = {}
122+ known_image_blobs = {}
123+ unknown_blobs = {}
124+ o_literals = {}
92125
93- # If name is not specified, or begins with _, this enures that it
126+ name = name .split ("/" )[- 1 ] # Use the last part of the name
127+ # If name is not specified, or begins with _, this ensures that it
94128 # complies with the ApertureDB naming conventions
95- name = f"E_{ name or 'Record' } "
129+ if not name or name .startswith ("_" ):
130+ safe_name = f"E_{ name or 'Record' } " # Uncomment if you want
131+ logger .warning (
132+ f"Entity Name '{ name } ' is not valid. Using { safe_name } ." )
133+ name = safe_name
96134
97135 for k , v in row_dict .items ():
98- k = f"F_{ k } "
136+ k = k .split ("/" )[- 1 ] # Use the last part of the key
137+ if not k or k .startswith ("_" ):
138+ safe_key = f"F_{ k or 'Field' } "
139+ logger .warning (
140+ f"Property name '{ k } ' is not valid. Using { safe_key } ." )
141+ k = safe_key
99142 item = v
143+ # Pre processed items from croissant.
100144 if isinstance (item , PIL .Image .Image ):
101145 buffer = io .BytesIO ()
102146 item .save (buffer , format = item .format )
103- blobs [k ] = buffer .getvalue ()
147+ known_image_blobs [k ] = buffer .getvalue ()
104148 continue
105149
106150 record = deserialize_record (item )
151+ if isinstance (record , str ):
152+ record = try_parse (record )
153+
154+ # Post processed items from SDK.
155+ if isinstance (record , PIL .GifImagePlugin .GifImageFile ):
156+ buffer = io .BytesIO ()
157+ record .save (buffer , format = record .format )
158+ unknown_blobs [k ] = buffer .getvalue ()
159+ continue
160+
161+ if isinstance (record , PIL .Image .Image ):
162+ buffer = io .BytesIO ()
163+ record .save (buffer , format = record .format )
164+ known_image_blobs [k ] = buffer .getvalue ()
165+ continue
166+
107167 if flatten_json and isinstance (record , list ):
108168 subitems [k ] = record
109169 else :
110170 literals [k ] = record
111- o_literalse [k ] = item
171+ # Original value from croissant. This is useful for debugging.
172+ o_literals [k ] = item
112173
113174 if flatten_json :
114175 str_rep = "" .join ([f"{ str (k )} { str (v )} " for k , v in literals .items ()])
115176 literals ["adb_uuid" ] = hashlib .sha256 (
116177 str_rep .encode ('utf-8' )).hexdigest ()
117178
118- literals ["adb_class_name" ] = name
179+ literals [CLASS_PROPERTY_NAME ] = name
119180 q = QueryBuilder .add_command (name , {
120181 "properties" : literals ,
121182 "connect" : {
@@ -130,34 +191,48 @@ def dict_to_query(row_dict, name: str, flatten_json: bool) -> Any:
130191 }
131192
132193 dependents = []
133- if len (subitems ) > 0 or len (blobs ) > 0 :
194+ if len (subitems ) > 0 or len (known_image_blobs ) > 0 or len (unknown_blobs ) > 0 :
195+ # We need to create a reference to this record
134196 q [list (q .keys ())[- 1 ]]["_ref" ] = 1
135197
136198 for key in subitems :
137199 for item in subitems [key ]:
138- subitem_query = dict_to_query (item , f"{ name } .{ key } " , flatten_json )
200+ subitem_query , blobs = dict_to_query (
201+ item , f"{ name } .{ key } " , flatten_json )
139202 subitem_query [0 ][list (subitem_query [0 ].keys ())[- 1 ]]["connect" ] = {
140203 "ref" : 1 ,
141204 "class" : key ,
142- "direction" : "out " ,
205+ "direction" : "in " ,
143206 }
144207 dependents .extend (subitem_query )
145208
146209 from aperturedb .Query import ObjectType
147- image_blobs = []
148- for blob in blobs :
210+ blobs = []
211+ for blob in known_image_blobs :
149212 image_query = QueryBuilder .add_command (ObjectType .IMAGE , {
150- "properties" : literals ,
213+ "properties" : { CLASS_PROPERTY_NAME : literals [ CLASS_PROPERTY_NAME ] + "." + "image" } ,
151214 "connect" : {
152215 "ref" : 1 ,
153216 "class" : blob ,
154- "direction" : "out "
217+ "direction" : "in "
155218 }
156219 })
157- image_blobs .append (blobs [blob ])
220+ blobs .append (known_image_blobs [blob ])
158221 dependents .append (image_query )
159222
160- return [q ] + dependents , image_blobs
223+ for blob in unknown_blobs :
224+ blob_query = QueryBuilder .add_command (ObjectType .BLOB , {
225+ "properties" : {CLASS_PROPERTY_NAME : literals [CLASS_PROPERTY_NAME ] + "." + "blob" },
226+ "connect" : {
227+ "ref" : 1 ,
228+ "class" : blob ,
229+ "direction" : "in"
230+ }
231+ })
232+ blobs .append (unknown_blobs [blob ])
233+ dependents .append (blob_query )
234+
235+ return [q ] + dependents , blobs
161236
162237
163238class MLCroissantRecordSet (Subscriptable ):
@@ -178,16 +253,14 @@ def __init__(
178253 if count == sample_count :
179254 break
180255
181- self .df = pd . json_normalize ( samples )
256+ self .samples = samples
182257 self .sample_count = len (samples )
183258 self .name = name
184259 self .flatten_json = flatten_json
185260 self .indexed_entities = set ()
186261
187262 def getitem (self , subscript ):
188- row = self .df .iloc [subscript ]
189- # Convert the row to a dictionary
190- row_dict = row .to_dict ()
263+ row_dict = self .samples [subscript ]
191264
192265 find_recordset_query = QueryBuilder .find_command (
193266 "RecordSetModel" , {
@@ -201,7 +274,7 @@ def getitem(self, subscript):
201274 indexes_to_create = []
202275 for command in q :
203276 cmd = list (command .keys ())[- 1 ]
204- if cmd == "AddImage" :
277+ if cmd in [ "AddImage" , "AddBlob" , "AddVideo" ] :
205278 continue
206279 indexable_entity = command [list (command .keys ())[- 1 ]]["class" ]
207280 if indexable_entity not in self .indexed_entities :
@@ -216,4 +289,4 @@ def getitem(self, subscript):
216289 return indexes_to_create + [find_recordset_query ] + q , blobs
217290
218291 def __len__ (self ):
219- return len (self .df )
292+ return len (self .samples )
0 commit comments