Skip to content

Commit 123effc

Browse files
authored
Merge pull request #596 from aperture-data/release-0.4.48
- Brings in bug fixes in using APERTUREDB_KEY. - Addresses some errors from CI process - ingest from_croissant is available as adb command, and MLCroissant as an SDK class.
2 parents 1db424c + c8bca94 commit 123effc

File tree

12 files changed

+436
-44
lines changed

12 files changed

+436
-44
lines changed

aperturedb/CommonLibrary.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,12 @@ def create_connector(
164164
This function chooses a configuration in the following order:
165165
1. The configuration named by the `name` parameter or `key` parameter
166166
2. The configuration described in the `APERTUREDB_KEY` environment variable.
167-
3. The configuration described in the `APERTUREDB_JSON` environment variable.
168-
4. The configuration described in the `APERTUREDB_JSON` Google Colab secret.
169-
5. The configuration described in the `APERTUREDB_JSON` secret in a `.env` file.
170-
6. The configuration named by the `APERTUREDB_CONFIG` environment variable.
171-
7. The active configuration.
167+
3. The configuration described in the `APERTUREDB_KEY` Google Colab secret.
168+
4. The configuration described in the `APERTUREDB_JSON` environment variable.
169+
5. The configuration described in the `APERTUREDB_JSON` Google Colab secret.
170+
6. The configuration described in the `APERTUREDB_JSON` secret in a `.env` file.
171+
7. The configuration named by the `APERTUREDB_CONFIG` environment variable.
172+
8. The active configuration.
172173
173174
If there are both global and local configurations with the same name, the global configuration is preferred.
174175
@@ -214,6 +215,14 @@ def lookup_config_by_name(name: str, source: str) -> Configuration:
214215
logger.info(
215216
f"Using configuration from APERTUREDB_KEY environment variable")
216217
config = Configuration.reinflate(data)
218+
elif (data := _get_colab_secret("APERTUREDB_KEY")) is not None and data != "":
219+
logger.info(
220+
f"Using configuration from APERTUREDB_KEY Google Colab secret")
221+
config = Configuration.reinflate(data)
222+
if create_config_for_colab_secret:
223+
logger.info(
224+
f"Creating and activating configuration from APERTUREDB_KEY Google Colab secret")
225+
_store_config(config, 'google_colab')
217226
elif (data := os.environ.get("APERTUREDB_JSON")) is not None and data != "":
218227
logger.info(
219228
f"Using configuration from APERTUREDB_JSON environment variable")
@@ -315,7 +324,7 @@ def execute_query(client: Connector, query: Commands,
315324
warn_list.append(wr)
316325
if len(warn_list) != 0:
317326
logger.warning(
318-
f"Partial errors:\r\n{json.dumps(query)}\r\n{json.dumps(warn_list)}")
327+
f"Partial errors:\r\n{json.dumps(query, default=str)}\r\n{json.dumps(warn_list, default=str)}")
319328
result = 2
320329

321330
return result, r, b

aperturedb/Connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def _query(self, query, blob_array = [], try_resume=True):
376376
response_blob_array = []
377377
# Check the query type
378378
if not isinstance(query, str): # assumes json
379-
query_str = json.dumps(query)
379+
query_str = json.dumps(query, default=str)
380380
else:
381381
query_str = query
382382

aperturedb/ConnectorRest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ def __init__(self, host="localhost", port=None,
7777
# A Convenience feature to not require the port
7878
# Relies on common ports for http and https, but can be overridden
7979
if port is None:
80-
self.port = 443 if use_ssl else 80
80+
if config is None:
81+
self.port = 443 if use_ssl else 80
82+
else:
83+
self.port = config.port
8184
else:
8285
self.port = port
8386

aperturedb/EntityUpdateDataCSV.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ class SingleEntityUpdateDataCSV(CSVParser.CSVParser):
2525
- a series of updateif_ to determine if an update is necessary
2626
2727
Conditionals:
28-
updateif>_prop - updates if the database value > csv value
29-
updateif<_prop - updates if the database value < csv value
30-
updateif!_prop - updates if the database value is != csv value
28+
```updateif>_prop``` : updates if the database value greater than csv value
29+
```updateif<_prop``` : updates if the database value less than csv value
30+
```updateif!_prop``` : updates if the database value is not equal to csv value
3131
3232
:::note
3333
Is backed by a CSV file with the following columns (format optional):

aperturedb/MLCroissant.py

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
import dataclasses
2+
import hashlib
3+
import io
4+
import json
5+
import logging
6+
import PIL.GifImagePlugin
7+
import mlcroissant as mlc
8+
import PIL.Image
9+
import pandas as pd
10+
11+
from typing import Any, List, Tuple
12+
13+
from aperturedb.Subscriptable import Subscriptable
14+
from aperturedb.Query import QueryBuilder
15+
from aperturedb.DataModels import IdentityDataModel
16+
from aperturedb.Query import generate_add_query
17+
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
MAX_REF_VALUE = 99999
23+
# This is useful to identify the class of the record in ApertureDB.
24+
CLASS_PROPERTY_NAME = "adb_class_name"
25+
26+
27+
class RecordSetModel(IdentityDataModel):
    # Represents one croissant RecordSet persisted in ApertureDB.
    # Each ingested record is connected back to its RecordSet entity.
    name: str
    description: str = ""
    # croissant RecordSet uuid; MLCroissantRecordSet.getitem matches on
    # this value to find the RecordSet entity when connecting records.
    uuid: str = ""
31+
32+
33+
class DatasetModel(IdentityDataModel):
    # Top-level entity representing one croissant dataset ingested into
    # ApertureDB; owns a list of RecordSetModel children.
    url: str = ""
    name: str = "Croissant Dataset automatically ingested into ApertureDB"
    # Plain string literal: the original used an f-string with no
    # placeholders, which is equivalent but misleading.
    description: str = "A dataset loaded from a croissant json-ld"
    version: str = "1.0.0"
    # NOTE(review): dataclasses.field assumes IdentityDataModel behaves as
    # a dataclass; if it is a pydantic model this default needs
    # pydantic.Field(default_factory=list) instead — confirm.
    record_sets: List[RecordSetModel] = dataclasses.field(default_factory=list)
39+
40+
41+
def deserialize_record(record):
    """Normalize a single croissant record value into JSON-friendly data.

    Handles the value types croissant/pandas produce: ``None`` and ``NaT``
    become placeholder strings, ``bytes`` are decoded as UTF-8, timestamps
    become ``{"_date": <iso-string>}``, and strings that look like JSON are
    parsed. Lists and dicts are normalized recursively.

    Args:
        record: raw value from a croissant record (None, bytes,
            pandas Timestamp/NaT, str, list, dict, or any other scalar).

    Returns:
        The normalized value.
    """
    deserialized = record
    if record is None:
        deserialized = "Not Available"
    elif isinstance(record, bytes):
        deserialized = record.decode('utf-8')
    # BUG FIX: NaT must be tested with `is` — `record == pd.NaT` always
    # evaluates to False (NaT never compares equal to anything), so the
    # "Not Available Time" branch in the original could never fire.
    # Checked before the Timestamp branch so NaT is never date-converted.
    elif record is pd.NaT:
        deserialized = "Not Available Time"
    elif isinstance(record, pd.Timestamp):
        deserialized = {"_date": record.to_pydatetime().isoformat()}
    if isinstance(deserialized, str):
        if deserialized.startswith("[") or deserialized.startswith("{"):
            # If it looks like a list or dict, try to parse it as JSON.
            try:
                deserialized = json.loads(deserialized)
            except json.JSONDecodeError:
                logger.info(f"Failed to parse JSON: {deserialized}")
                # Second chance: some values are Python-repr style with
                # single quotes; retry after swapping quote characters.
                try:
                    deserialized = json.loads(deserialized.replace("'", "\""))
                except Exception as e:
                    logger.info(
                        f"Failed to parse JSON: {deserialized} with error {e}")

    if isinstance(deserialized, list):
        deserialized = [deserialize_record(item) for item in deserialized]
    if isinstance(deserialized, dict):
        deserialized = {k: deserialize_record(
            v) for k, v in deserialized.items()}

    return deserialized
81+
82+
83+
def persist_metadata(dataset: mlc.Dataset, url: str) -> Tuple[List[dict], List[bytes]]:
    """Build the ApertureDB add-query for a croissant dataset's metadata.

    Maps the croissant metadata onto a DatasetModel (with one
    RecordSetModel per record set) and returns the (query, blobs) pair
    produced by generate_add_query; the caller executes the transaction.
    """
    record_set_models = [
        RecordSetModel(
            name=record_set.name,
            description=record_set.description,
            uuid=record_set.uuid,
        )
        for record_set in dataset.metadata.record_sets
    ]
    dataset_model = DatasetModel(
        url=url,
        name=dataset.metadata.name,
        description=dataset.metadata.description,
        # croissant allows a missing version; fall back to a sane default.
        version=dataset.metadata.version or "1.0.0",
        record_sets=record_set_models,
    )
    query, blobs, _ = generate_add_query(dataset_model)
    return query, blobs
101+
102+
103+
def try_parse(value: str) -> Any:
    """Attempt to turn a string value into a richer object.

    Strings beginning with "http" are treated as URLs: the content is
    downloaded and, on success, returned as a PIL image. Anything else
    (or a failed download) is returned as the stripped string.
    """
    candidate = value.strip()
    if not candidate.startswith("http"):
        return candidate

    # Download the content from the URL (up to 3 retries).
    from aperturedb.Sources import Sources
    loader = Sources(n_download_retries=3)
    ok, payload = loader.load_from_http_url(candidate, validator=lambda x: True)
    if ok:
        return PIL.Image.open(io.BytesIO(payload))
    return candidate
117+
118+
119+
def dict_to_query(row_dict, name: str, flatten_json: bool) -> Any:
    """Convert one croissant record (a dict) into ApertureDB add commands.

    Args:
        row_dict: one record's field -> value mapping.
        name: entity class name (only the last "/" segment is used).
        flatten_json: when True, list-valued fields become connected
            sub-entities and the entity gets a deterministic adb_uuid used
            for de-duplication.

    Returns:
        A tuple (queries, blobs): queries is the parent add command
        followed by dependent commands (sub-entities, images, blobs);
        blobs holds the binary payloads in the same order as the
        blob-bearing commands in queries.
    """
    from aperturedb.Query import ObjectType

    literals = {}            # scalar properties of this entity
    subitems = {}            # list-valued fields -> connected sub-entities
    known_image_blobs = {}   # fields already decoded as images
    unknown_blobs = {}       # fields stored as opaque blobs

    name = name.split("/")[-1]  # Use the last part of the name
    # Entity names must be non-empty and must not start with "_"
    # (ApertureDB naming convention); prefix to make them valid.
    if not name or name.startswith("_"):
        safe_name = f"E_{name or 'Record'}"
        logger.warning(
            f"Entity Name '{name}' is not valid. Using {safe_name}.")
        name = safe_name

    for k, v in row_dict.items():
        k = k.split("/")[-1]  # Use the last part of the key
        if not k or k.startswith("_"):
            safe_key = f"F_{k or 'Field'}"
            logger.warning(
                f"Property name '{k}' is not valid. Using {safe_key}.")
            k = safe_key
        item = v
        # Pre-processed items from croissant: already-decoded images.
        if isinstance(item, PIL.Image.Image):
            buffer = io.BytesIO()
            item.save(buffer, format=item.format)
            known_image_blobs[k] = buffer.getvalue()
            continue

        record = deserialize_record(item)
        if isinstance(record, str):
            record = try_parse(record)

        # Post-processed items from the SDK (try_parse may return images).
        if isinstance(record, PIL.GifImagePlugin.GifImageFile):
            buffer = io.BytesIO()
            record.save(buffer, format=record.format)
            unknown_blobs[k] = buffer.getvalue()
            continue

        if isinstance(record, PIL.Image.Image):
            buffer = io.BytesIO()
            record.save(buffer, format=record.format)
            known_image_blobs[k] = buffer.getvalue()
            continue

        if flatten_json and isinstance(record, list):
            subitems[k] = record
        else:
            literals[k] = record

    if flatten_json:
        # Deterministic id derived from all scalar properties; used for
        # de-duplication via if_not_found below.
        str_rep = "".join([f"{str(k)}{str(v)}" for k, v in literals.items()])
        literals["adb_uuid"] = hashlib.sha256(
            str_rep.encode('utf-8')).hexdigest()

    literals[CLASS_PROPERTY_NAME] = name
    q = QueryBuilder.add_command(name, {
        "properties": literals,
        "connect": {
            "ref": MAX_REF_VALUE,
            "class": "hasRecord",
            "direction": "in",
        }
    })
    if flatten_json:
        q[list(q.keys())[-1]]["if_not_found"] = {
            "adb_uuid": ["==", literals["adb_uuid"]]
        }

    dependents = []
    blobs = []
    if len(subitems) > 0 or len(known_image_blobs) > 0 or len(unknown_blobs) > 0:
        # Dependent commands connect back to this record via ref 1.
        q[list(q.keys())[-1]]["_ref"] = 1

    for key in subitems:
        for item in subitems[key]:
            subitem_query, subitem_blobs = dict_to_query(
                item, f"{name}.{key}", flatten_json)
            subitem_query[0][list(subitem_query[0].keys())[-1]]["connect"] = {
                "ref": 1,
                "class": key,
                "direction": "in",
            }
            dependents.extend(subitem_query)
            # BUG FIX: the original discarded the blobs returned by this
            # recursive call (the local was later rebound to []), so
            # sub-entity image/blob payloads never accompanied their
            # commands. Keep them in command order.
            blobs.extend(subitem_blobs)

    for blob in known_image_blobs:
        image_query = QueryBuilder.add_command(ObjectType.IMAGE, {
            "properties": {CLASS_PROPERTY_NAME: literals[CLASS_PROPERTY_NAME] + "." + "image"},
            "connect": {
                "ref": 1,
                "class": blob,
                "direction": "in"
            }
        })
        blobs.append(known_image_blobs[blob])
        dependents.append(image_query)

    for blob in unknown_blobs:
        blob_query = QueryBuilder.add_command(ObjectType.BLOB, {
            "properties": {CLASS_PROPERTY_NAME: literals[CLASS_PROPERTY_NAME] + "." + "blob"},
            "connect": {
                "ref": 1,
                "class": blob,
                "direction": "in"
            }
        })
        blobs.append(unknown_blobs[blob])
        dependents.append(blob_query)

    return [q] + dependents, blobs
236+
237+
238+
class MLCroissantRecordSet(Subscriptable):
    """Subscriptable view over a croissant record set.

    Materializes up to `sample_count` records eagerly in __init__ and,
    per subscript, builds the ApertureDB transaction (index creation,
    RecordSet lookup, and add commands) for one record.
    """

    def __init__(
            self,
            record_set: mlc.Records,
            name: str,
            flatten_json: bool,
            sample_count: int = 0,
            uuid: str = None):
        self.record_set = record_set
        # uuid of the owning RecordSetModel entity; used to connect records.
        self.uuid = uuid
        samples = []
        count = 0
        for record in record_set:
            samples.append({k: v for k, v in record.items()})
            count += 1
            # sample_count == 0 means "take everything": count starts at 1
            # after the first increment, so the break never triggers.
            if count == sample_count:
                break

        self.samples = samples
        self.sample_count = len(samples)
        self.name = name
        self.flatten_json = flatten_json
        # Entity classes for which a CreateIndex command was already emitted.
        self.indexed_entities = set()

    def getitem(self, subscript):
        """Build (queries, blobs) for the sample at `subscript`."""
        row_dict = self.samples[subscript]

        # Locate the RecordSet entity so each record connects back to it
        # (dict_to_query uses MAX_REF_VALUE as the connect ref).
        find_recordset_query = QueryBuilder.find_command(
            "RecordSetModel", {
                "_ref": MAX_REF_VALUE,
                "constraints": {
                    "uuid": ["==", self.uuid]
                }
            })

        q, blobs = dict_to_query(row_dict, self.name, self.flatten_json)
        indexes_to_create = []
        for command in q:
            cmd = list(command.keys())[-1]
            # Images/blobs/videos are not entity classes; nothing to index.
            if cmd in ["AddImage", "AddBlob", "AddVideo"]:
                continue
            indexable_entity = command[list(command.keys())[-1]]["class"]
            if indexable_entity not in self.indexed_entities:
                index_command = {
                    "CreateIndex": {
                        "class": indexable_entity,
                        "index_type": "entity",
                        "property_key": "adb_uuid",
                    }
                }
                indexes_to_create.append(index_command)
                # BUG FIX: the original never added to indexed_entities,
                # so a duplicate CreateIndex was emitted for every sample.
                self.indexed_entities.add(indexable_entity)
        return indexes_to_create + [find_recordset_query] + q, blobs

    def __len__(self):
        return len(self.samples)

aperturedb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import signal
1111
import sys
1212

13-
__version__ = "0.4.47"
13+
__version__ = "0.4.48"
1414

1515
logger = logging.getLogger(__name__)
1616

aperturedb/cli/configure.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,12 @@ def check_for_overwrite(name):
198198
gen_config = _create_configuration_from_json(
199199
json_str, name=name, name_required=True)
200200
check_for_overwrite(gen_config.name)
201-
name = gen_config.name
201+
name = name if name is not None else gen_config.name
202202
elif from_key:
203203
assert interactive, "Interactive mode must be enabled for --from-key"
204204
encoded_str = typer.prompt("Enter encoded string", hide_input=True)
205205
gen_config = Configuration.reinflate(encoded_str)
206-
name = gen_config.name
206+
name = name if name is not None else gen_config.name
207207

208208
else:
209209
if interactive:

0 commit comments

Comments
 (0)