
Commit db2f8d9

Adds a class that can iterate over Croissant datasets and generate queries to persist them to ApertureDB. (#594)
1 parent a58fe1e commit db2f8d9

8 files changed, +342 −35 lines changed

aperturedb/CommonLibrary.py

Lines changed: 1 addition & 1 deletion
@@ -315,7 +315,7 @@ def execute_query(client: Connector, query: Commands,
                 warn_list.append(wr)
         if len(warn_list) != 0:
             logger.warning(
-                f"Partial errors:\r\n{json.dumps(query)}\r\n{json.dumps(warn_list)}")
+                f"Partial errors:\r\n{json.dumps(query, default=str)}\r\n{json.dumps(warn_list, default=str)}")
             result = 2

     return result, r, b
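
The same `default=str` fix is applied to `Connector._query` just below. As a minimal illustration of what it buys (an example for this page, not code from the commit): query dictionaries can carry values such as `datetime` objects that the stock JSON encoder rejects, and `default=str` makes the serialization fall back to their string form instead of raising.

```python
import json
from datetime import datetime

# A query whose properties include a value json.dumps cannot encode natively.
query = [{"AddEntity": {"class": "Person",
                        "properties": {"last_seen": datetime(2024, 1, 1)}}}]

# json.dumps(query) raises: TypeError: Object of type datetime is not JSON serializable
print(json.dumps(query, default=str))
# [{"AddEntity": {"class": "Person", "properties": {"last_seen": "2024-01-01 00:00:00"}}}]
```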

aperturedb/Connector.py

Lines changed: 1 addition & 1 deletion
@@ -376,7 +376,7 @@ def _query(self, query, blob_array = [], try_resume=True):
         response_blob_array = []
         # Check the query type
         if not isinstance(query, str):  # assumes json
-            query_str = json.dumps(query)
+            query_str = json.dumps(query, default=str)
         else:
             query_str = query


aperturedb/ConnectorRest.py

Lines changed: 4 additions & 1 deletion
@@ -77,7 +77,10 @@ def __init__(self, host="localhost", port=None,
         # A Convenience feature to not require the port
         # Relies on common ports for http and https, but can be overriden
         if port is None:
-            self.port = 443 if use_ssl else 80
+            if config is None:
+                self.port = 443 if use_ssl else 80
+            else:
+                self.port = config.port
         else:
             self.port = port

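
Restating the new port-selection order as a standalone sketch (this helper is illustrative, not part of the commit): an explicit `port` argument still takes precedence, then the `port` carried by a passed-in configuration object, and only then the scheme default.

```python
def resolve_rest_port(port=None, config=None, use_ssl=True):
    # Illustrative mirror of the patched ConnectorRest.__init__ branch:
    # explicit port > config.port > conventional HTTPS/HTTP default.
    if port is None:
        if config is None:
            return 443 if use_ssl else 80
        return config.port
    return port
```

So a connector built purely from a stored configuration now inherits that configuration's port instead of silently defaulting to 443 or 80.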

aperturedb/EntityUpdateDataCSV.py

Lines changed: 3 additions & 3 deletions
@@ -25,9 +25,9 @@ class SingleEntityUpdateDataCSV(CSVParser.CSVParser):
     - a series of updateif_ to determine if an update is necessary

     Conditionals:
-    updateif>_prop - updates if the database value > csv value
-    updateif<_prop - updates if the database value < csv value
-    updateif!_prop - updates if the database value is != csv value
+    ```updateif>_prop``` : updates if the database value greater than csv value
+    ```updateif<_prop``` : updates if the database value less than csv value
+    ```updateif!_prop``` : updates if the database value is not equal to csv value

     :::note
     Is backed by a CSV file with the following columns (format optional):
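
The column list that follows in this docstring is truncated in the diff, so rather than guess at the CSV layout, here is a small sketch of just the comparison each prefix encodes, written from the wording above (the helper name is invented for illustration):

```python
def should_update(prefix: str, db_value, csv_value) -> bool:
    # Assumed semantics of the conditional column prefixes, per the docstring.
    if prefix == "updateif>_":
        return db_value > csv_value   # database value greater than CSV value
    if prefix == "updateif<_":
        return db_value < csv_value   # database value less than CSV value
    if prefix == "updateif!_":
        return db_value != csv_value  # database value differs from CSV value
    return False
```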

aperturedb/MLCroissant.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
import io
import json
from typing import Any, List, Tuple

import PIL
import PIL.Image
import mlcroissant as mlc
import pandas as pd


from aperturedb.Subscriptable import Subscriptable
from aperturedb.Query import QueryBuilder
from aperturedb.CommonLibrary import execute_query


import dataclasses
import hashlib

from aperturedb.DataModels import IdentityDataModel
from aperturedb.Query import generate_add_query

MAX_REF_VALUE = 99999


class RecordSetModel(IdentityDataModel):
    name: str
    description: str = ""
    uuid: str = ""


class DatasetModel(IdentityDataModel):
    name: str = "Croissant Dataset automatically ingested into ApertureDB"
    description: str = f"A dataset loaded from a croissant json-ld"
    version: str = "1.0.0"
    record_sets: List[RecordSetModel] = dataclasses.field(default_factory=list)


def deserialize_record(record):
    """These are the types of records that we expect to deserialize, from croissant Records.

    Args:
        record (_type_): _description_

    Returns:
        _type_: _description_
    """
    deserialized = record
    if record is None:
        deserialized = "Not Available"
    if isinstance(record, bytes):
        deserialized = record.decode('utf-8')
    if isinstance(record, pd.Timestamp):
        deserialized = {"_date": record.to_pydatetime().isoformat()}
    if record == pd.NaT:
        deserialized = "Not Available Time"
    if isinstance(deserialized, str):
        try:
            deserialized = json.loads(deserialized)
        except:
            pass
    if isinstance(deserialized, list):
        deserialized = [deserialize_record(item) for item in deserialized]
    if isinstance(deserialized, dict):
        deserialized = {k: deserialize_record(
            v) for k, v in deserialized.items()}

    return deserialized


def persist_metadata(dataset: mlc.Dataset) -> Tuple[List[dict], List[bytes]]:

    ds = DatasetModel(
        name=dataset.metadata.name,
        description=dataset.metadata.description,
        version=dataset.metadata.version or "1.0.0",
        record_sets=[RecordSetModel(
            name=rs.name,
            description=rs.description,
            uuid=rs.uuid,
        ) for rs in dataset.metadata.record_sets]
    )
    q, b, _ = generate_add_query(ds)

    return q, b


def dict_to_query(row_dict, name: str, flatten_json: bool) -> Any:
    literals = {}
    subitems = {}
    blobs = {}
    o_literalse = {}

    # If name is not specified, or begins with _, this ensures that it
    # complies with the ApertureDB naming conventions
    name = f"E_{name or 'Record'}"

    for k, v in row_dict.items():
        k = f"F_{k}"
        item = v
        if isinstance(item, PIL.Image.Image):
            buffer = io.BytesIO()
            item.save(buffer, format=item.format)
            blobs[k] = buffer.getvalue()
            continue

        record = deserialize_record(item)
        if flatten_json and isinstance(record, list):
            subitems[k] = record
        else:
            literals[k] = record
            o_literalse[k] = item

    if flatten_json:
        str_rep = "".join([f"{str(k)}{str(v)}" for k, v in literals.items()])
        literals["adb_uuid"] = hashlib.sha256(
            str_rep.encode('utf-8')).hexdigest()

    literals["adb_class_name"] = name
    q = QueryBuilder.add_command(name, {
        "properties": literals,
        "connect": {
            "ref": MAX_REF_VALUE,
            "class": "hasRecord",
            "direction": "in",
        }
    })
    if flatten_json:
        q[list(q.keys())[-1]]["if_not_found"] = {
            "adb_uuid": ["==", literals["adb_uuid"]]
        }

    dependents = []
    if len(subitems) > 0 or len(blobs) > 0:
        q[list(q.keys())[-1]]["_ref"] = 1

    for key in subitems:
        for item in subitems[key]:
            subitem_query = dict_to_query(item, f"{name}.{key}", flatten_json)
            subitem_query[0][list(subitem_query[0].keys())[-1]]["connect"] = {
                "ref": 1,
                "class": key,
                "direction": "out",
            }
            dependents.extend(subitem_query)

    from aperturedb.Query import ObjectType
    image_blobs = []
    for blob in blobs:
        image_query = QueryBuilder.add_command(ObjectType.IMAGE, {
            "properties": literals,
            "connect": {
                "ref": 1,
                "class": blob,
                "direction": "out"
            }
        })
        image_blobs.append(blobs[blob])
        dependents.append(image_query)

    return [q] + dependents, image_blobs


class MLCroissantRecordSet(Subscriptable):
    def __init__(
            self,
            record_set: mlc.Records,
            name: str,
            flatten_json: bool,
            sample_count: int = 0,
            uuid: str = None):
        self.record_set = record_set
        self.uuid = uuid
        samples = []
        count = 0
        for record in record_set:
            samples.append({k: v for k, v in record.items()})
            count += 1
            if count == sample_count:
                break

        self.df = pd.json_normalize(samples)
        self.sample_count = len(samples)
        self.name = name
        self.flatten_json = flatten_json
        self.indexed_entities = set()

    def getitem(self, subscript):
        row = self.df.iloc[subscript]
        # Convert the row to a dictionary
        row_dict = row.to_dict()

        find_recordset_query = QueryBuilder.find_command(
            "RecordSetModel", {
                "_ref": MAX_REF_VALUE,
                "constraints": {
                    "uuid": ["==", self.uuid]
                }
            })

        q, blobs = dict_to_query(row_dict, self.name, self.flatten_json)
        indexes_to_create = []
        for command in q:
            cmd = list(command.keys())[-1]
            if cmd == "AddImage":
                continue
            indexable_entity = command[list(command.keys())[-1]]["class"]
            if indexable_entity not in self.indexed_entities:
                index_command = {
                    "CreateIndex": {
                        "class": indexable_entity,
                        "index_type": "entity",
                        "property_key": "adb_uuid",
                    }
                }
                indexes_to_create.append(index_command)
        return indexes_to_create + [find_recordset_query] + q, blobs

    def __len__(self):
        return len(self.df)
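
For orientation, a minimal end-to-end sketch of how the new module is meant to be driven (not part of the commit: the dataset URL is a placeholder, `create_connector` is the connection helper from `aperturedb.CommonLibrary`, and the `mlcroissant` record-iteration call follows that library's public examples, so treat the exact invocation as an assumption):

```python
import mlcroissant as mlc

from aperturedb.CommonLibrary import create_connector, execute_query
from aperturedb.MLCroissant import MLCroissantRecordSet, persist_metadata

dataset = mlc.Dataset("https://example.org/dataset/croissant.json")  # placeholder URL
client = create_connector()

# Persist the DatasetModel / RecordSetModel metadata first, so that each record's
# FindEntity("RecordSetModel") lookup by uuid has something to connect to.
q, blobs = persist_metadata(dataset)
execute_query(client, q, blobs)

# Ingest a small sample from the first record set.
rs = dataset.metadata.record_sets[0]
records = MLCroissantRecordSet(
    dataset.records(record_set=rs.uuid),  # assumed mlcroissant iteration API
    name=rs.name,
    flatten_json=True,
    sample_count=10,  # 0 means "no limit" in the constructor's break condition
    uuid=rs.uuid,
)

for i in range(len(records)):
    query, record_blobs = records[i]
    execute_query(client, query, record_blobs)
```

Each item already prepends its own `CreateIndex` and `FindEntity` commands, so every record can be sent as a self-contained transaction; a parallel loader could consume the same `Subscriptable` interface, but the plain loop keeps the sketch dependency-free.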
