Skip to content

Commit 31f7492

Browse files
authored
Merge pull request #38 from aperture-data/release-0.1.1
Release 0.1.1
2 parents 0dd9204 + 63cc506 commit 31f7492

File tree

14 files changed

+381
-44
lines changed

14 files changed

+381
-44
lines changed

.github/workflows/main.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,13 @@ jobs:
4040
password: ${{ secrets.DOCKER_PASS }}
4141

4242
- name: Run Tests
43+
env:
44+
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
45+
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
46+
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
4347
run: |
4448
pip3 install .
45-
pip3 install ipython torch torchvision
49+
pip3 install ipython torch torchvision boto3
4650
cd test
4751
bash run_test.sh
4852

aperturedb/CSVParser.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def parse_properties(self, df, idx):
4747
prop = key[len("date:"):] # remove prefix
4848
properties[prop] = {"_date": self.df.loc[idx, key]}
4949
else:
50-
properties[key] = self.df.loc[idx, key]
50+
value = self.df.loc[idx, key]
51+
if value == value: # skips nan values
52+
properties[key] = value
5153

5254
return properties
5355

aperturedb/ImageLoader.py

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import math
22
import time
3+
import requests
4+
import os
5+
import boto3
36
from threading import Thread
47

58
import numpy as np
@@ -9,10 +12,12 @@
912
from aperturedb import ParallelLoader
1013
from aperturedb import CSVParser
1114

12-
HEADER_PATH = "filename"
13-
PROPERTIES = "properties"
14-
CONSTRAINTS = "constraints"
15-
IMG_FORMAT = "format"
15+
HEADER_PATH = "filename"
16+
HEADER_URL = "url"
17+
HEADER_S3_URL = "s3_url"
18+
PROPERTIES = "properties"
19+
CONSTRAINTS = "constraints"
20+
IMG_FORMAT = "format"
1621

1722
class ImageGeneratorCSV(CSVParser.CSVParser):
1823

@@ -21,6 +26,11 @@ class ImageGeneratorCSV(CSVParser.CSVParser):
2126
Expects a csv file with the following columns (format optional):
2227
2328
filename,PROP_NAME_1, ... PROP_NAME_N,constraint_PROP1,format
29+
OR
30+
url,PROP_NAME_1, ... PROP_NAME_N,constraint_PROP1,format
31+
OR
32+
s3_url,PROP_NAME_1, ... PROP_NAME_N,constraint_PROP1,format
33+
...
2434
2535
Example csv file:
2636
filename,id,label,constraint_id,format
@@ -29,7 +39,7 @@ class ImageGeneratorCSV(CSVParser.CSVParser):
2939
...
3040
'''
3141

32-
def __init__(self, filename, check_image=True):
42+
def __init__(self, filename, check_image=True, n_download_retries=3):
3343

3444
super().__init__(filename)
3545

@@ -40,13 +50,31 @@ def __init__(self, filename, check_image=True):
4050
self.props_keys = [x for x in self.props_keys if x != IMG_FORMAT]
4151
self.constraints_keys = [x for x in self.header[1:] if x.startswith(CSVParser.CONTRAINTS_PREFIX) ]
4252

53+
self.source_type = self.header[0]
54+
if self.source_type not in [ HEADER_PATH, HEADER_URL, HEADER_S3_URL ]:
55+
print("Source not recognized: " + self.source_type)
56+
raise Exception("Error loading image: " + filename )
57+
58+
self.n_download_retries = n_download_retries
59+
4360
# TODO: we can add support for slicing here.
4461
def __getitem__(self, idx):
4562

46-
filename = self.df.loc[idx, HEADER_PATH]
4763
data = {}
4864

49-
img_ok, img = self.load_image(filename)
65+
img_ok = True
66+
img = None
67+
68+
if self.source_type == HEADER_PATH:
69+
image_path = self.df.loc[idx, HEADER_PATH]
70+
img_ok, img = self.load_image(image_path)
71+
elif self.source_type == HEADER_URL:
72+
image_path = self.df.loc[idx, HEADER_URL]
73+
img_ok, img = self.load_url(image_path)
74+
elif self.source_type == HEADER_S3_URL:
75+
image_path = self.df.loc[idx, HEADER_S3_URL]
76+
img_ok, img = self.load_s3_url(image_path)
77+
5078
if not img_ok:
5179
print("Error loading image: " + filename )
5280
raise Exception("Error loading image: " + filename )
@@ -67,12 +95,12 @@ def __getitem__(self, idx):
6795
return data
6896

6997
def load_image(self, filename):
70-
7198
if self.check_image:
7299
try:
73100
a = cv2.imread(filename)
74101
if a.size <= 0:
75102
print("IMAGE SIZE ERROR:", filename)
103+
return false, None
76104
except:
77105
print("IMAGE ERROR:", filename)
78106

@@ -83,14 +111,73 @@ def load_image(self, filename):
83111
return True, buff
84112
except:
85113
print("IMAGE ERROR:", filename)
114+
return False, None
115+
116+
def check_image_buffer(self, img):
117+
try:
118+
decoded_img = cv2.imdecode(img, cv2.IMREAD_COLOR)
119+
120+
# Check image is correct
121+
decoded_img = decoded_img if decoded_img is not None else img
122+
123+
return True
124+
except:
125+
return False
126+
127+
def load_url(self, url):
128+
retries = 0
129+
while True:
130+
imgdata = requests.get(url)
131+
if imgdata.ok:
132+
imgbuffer = np.frombuffer(imgdata.content, dtype='uint8')
133+
if self.check_image and not self.check_image_buffer(imgbuffer):
134+
print("IMAGE ERROR: ", url)
135+
return False, None
136+
137+
return imgdata.ok, imgdata.content
138+
else:
139+
if retries >= self.n_download_retries:
140+
break
141+
print("WARNING: Retrying object:", url)
142+
retries += 1
143+
time.sleep(2)
144+
145+
return False, None
146+
147+
def load_s3_url(self, s3_url):
148+
retries = 0
149+
150+
# The connections by boto3 cause ResourceWarning. Known
151+
# issue: https://github.com/boto/boto3/issues/454
152+
s3 = boto3.client('s3')
153+
154+
while True:
155+
try:
156+
bucket_name = s3_url.split("/")[2]
157+
object_name = s3_url.split("s3://" + bucket_name + "/")[-1]
158+
s3_response_object = s3.get_object(Bucket=bucket_name, Key=object_name)
159+
img = s3_response_object['Body'].read()
160+
imgbuffer = np.frombuffer(img, dtype='uint8')
161+
if self.check_image and not self.check_image_buffer(imgbuffer):
162+
print("IMAGE ERROR: ", s3_url)
163+
return False, None
164+
165+
return True, img
166+
except:
167+
if retries >= self.n_download_retries:
168+
break
169+
print("WARNING: Retrying object:", s3_url)
170+
retries += 1
171+
time.sleep(2)
86172

173+
print("S3 ERROR:", s3_url)
87174
return False, None
88175

89176
def validate(self):
90177

91178
self.header = list(self.df.columns.values)
92179

93-
if self.header[0] != HEADER_PATH:
180+
if self.header[0] not in [ HEADER_PATH, HEADER_URL, HEADER_S3_URL ]:
94181
raise Exception("Error with CSV file field: filename. Must be first field")
95182

96183
class ImageLoader(ParallelLoader.ParallelLoader):

aperturedb/ParallelLoader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ def ingest(self, generator, batchsize=1, numthreads=1, stats=False):
8585

8686
start_time = time.time()
8787

88-
elements_per_thread = math.ceil(self.total_elements / self.numthreads)
88+
if self.total_elements < batchsize:
89+
elements_per_thread = self.total_elements
90+
self.numthreads = 1
91+
else:
92+
elements_per_thread = math.ceil(self.total_elements / self.numthreads)
8993

9094
thread_arr = []
9195
for i in range(self.numthreads):

aperturedb/PyTorchDataset.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from torch.utils import data
1111
from torchvision import transforms
1212

13+
DEFAULT_BATCH_SIZE = 50
14+
1315
class ApertureDBDatasetConstraints(data.Dataset):
1416

1517
# initialise function of class
@@ -39,16 +41,19 @@ class ApertureDBDataset(data.Dataset):
3941
# initialise function of class
4042
def __init__(self, db, query, label_prop=None):
4143

42-
self.db = db
44+
self.db = db.create_new_connection()
4345
self.query = query
4446
self.find_image_idx = None
4547
self.total_elements = 0
46-
self.batch_size = 100
48+
self.batch_size = DEFAULT_BATCH_SIZE
4749
self.batch_images = []
4850
self.batch_start = 0
4951
self.batch_end = 0
5052
self.label_prop = label_prop
5153

54+
self.prev_requested = -1
55+
self.sequence_counter = DEFAULT_BATCH_SIZE
56+
5257
for i in range(len(query)):
5358

5459
name = list(query[i].keys())[0]
@@ -76,6 +81,18 @@ def __init__(self, db, query, label_prop=None):
7681

7782
def __getitem__(self, index):
7883

84+
if index == self.prev_requested + 1:
85+
self.sequence_counter += 1
86+
else:
87+
self.sequence_counter = 0
88+
89+
if self.sequence_counter >= DEFAULT_BATCH_SIZE:
90+
self.batch_size = DEFAULT_BATCH_SIZE
91+
else:
92+
self.batch_size = 1
93+
94+
self.prev_requested = index
95+
7996
if index >= self.total_elements:
8097
raise StopIteration
8198

@@ -116,7 +133,21 @@ def get_batch(self, index):
116133
query[self.find_image_idx]["FindImage"]["batch"] = qbatch
117134

118135
try:
119-
r,b = self.db.query(query)
136+
137+
# This is to handle potential issues with
138+
# disconnection/timeout and SSL context on multiprocessing
139+
connection_ok = False
140+
try:
141+
r,b = self.db.query(query)
142+
connection_ok = True
143+
except:
144+
# Connection failed, we retry just once to re-connect
145+
self.db = self.db.create_new_connection()
146+
147+
if not connection_ok:
148+
# Connection failed, we have reconnected, we try again.
149+
r,b = self.db.query(query)
150+
120151
if len(b) == 0:
121152
print("index:", index)
122153
raise Exception("No results returned from ApertureDB")

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="aperturedb",
8-
version="0.1.0",
8+
version="0.1.1",
99
description="ApertureDB Client Module",
1010
install_requires=['vdms', 'scikit-image', 'image',
1111
'opencv-python', 'numpy', 'matplotlib', 'pandas'],

test/aperturedb/config.json

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111

1212
"pmgd_disk_sync_option": "never_sync",
1313

14-
// Serialize graph access
15-
"serialize_graph_access": true,
16-
1714
"create_parameters": {
1815
"pmgd_num_allocators": 32,
1916
"pmgd_journal_size": 1024

test/dbinfo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# This file contains information on how to access the server
22

3-
DB_HOST="localhost"
4-
DB_PORT=55555
3+
DB_HOST = "localhost"
4+
DB_PORT = 55555
5+
DB_USER = "admin"
6+
DB_PASSWORD = "admin"

test/generateInput.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,52 @@ def generate_images_csv(multiplier):
9999

100100
return df
101101

102+
def generate_http_images_csv(ip_file_csv):
103+
104+
images = pd.read_csv(ip_file_csv, sep=",", header=None)
105+
106+
ids = [int(1000000000* random.random()) for i in range(len(images))]
107+
age = [int(100* random.random()) for i in range(len(images))]
108+
height = [float(200* random.random()) for i in range(len(images))]
109+
license = [x for x in range(len(images))]
110+
111+
df = pd.DataFrame()
112+
df['url'] = images
113+
df["urlid"] = ids
114+
df['license'] = license
115+
df["age"] = age
116+
df["height"] = height
117+
df["constraint_urlid"] = ids
118+
119+
df = df.sort_values("urlid")
120+
121+
df.to_csv("input/http_images.adb.csv", index=False)
122+
123+
return df
124+
125+
def generate_s3_images_csv(ip_file_csv):
126+
127+
images = pd.read_csv(ip_file_csv, sep=",", header=None)
128+
129+
ids = [int(1000000000* random.random()) for i in range(len(images))]
130+
age = [int(100* random.random()) for i in range(len(images))]
131+
height = [float(200* random.random()) for i in range(len(images))]
132+
license = [x for x in range(len(images))]
133+
134+
df = pd.DataFrame()
135+
df['s3_url'] = images
136+
df["id"] = ids
137+
df['license'] = license
138+
df["age"] = age
139+
df["height"] = height
140+
df["constraint_id"] = ids
141+
142+
df = df.sort_values("id")
143+
144+
df.to_csv("input/s3_images.adb.csv", index=False)
145+
146+
return df
147+
102148
def generate_connections_csv(persons, images):
103149

104150
connections = list(product(images["id"][::100], persons["id"][::100]))
@@ -183,6 +229,8 @@ def main(params):
183229
persons = generate_person_csv(params.multiplier)
184230
blobs = generate_blobs_csv()
185231
images = generate_images_csv(int(params.multiplier/2))
232+
s3_imgs = generate_http_images_csv("input/sample_http_urls.csv")
233+
s3_imgs = generate_s3_images_csv("input/sample_s3_urls.csv")
186234
connect = generate_connections_csv(persons, images)
187235
bboxes = generate_bboxes_csv(images)
188236

test/input/sample_http_urls.csv

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
https://aperturedata-public.s3.us-west-2.amazonaws.com/sample_images/1002318269_97db6e0975.jpg
2+
https://aperturedata-public.s3.us-west-2.amazonaws.com/sample_images/10201275523_3e6ea67c7f.jpg
3+
https://aperturedata-public.s3.us-west-2.amazonaws.com/sample_images/2297552664_1ee0e8855d.jpg
4+
https://aperturedata-public.s3.us-west-2.amazonaws.com/sample_images/4140939180_07aeded917.jpg
5+
https://aperturedata-public.s3.us-west-2.amazonaws.com/sample_images/4436463882_b96a3d9df9.jpg
6+
https://aperturedata-public.s3.us-west-2.amazonaws.com/sample_images/4572998878_658b45226f.jpg
7+
https://aperturedata-public.s3.us-west-2.amazonaws.com/sample_images/6985418911_df7747990d.jpg
8+
https://aperturedata-public.s3.us-west-2.amazonaws.com/sample_images/7289030198_1f1ba44113.jpg
9+
https://aperturedata-public.s3.us-west-2.amazonaws.com/sample_images/9329902958_0bc80ce58a.jpg
10+
https://aperturedata-public.s3.us-west-2.amazonaws.com/sample_images/9506922316_c19019e38f.jpg

0 commit comments

Comments
 (0)