Skip to content

Commit 9bd1eab

Browse files
authored
Merge pull request #11 from aperture-data/pytorch_connectors
PyTorch connectors
2 parents f1ca6fa + 14c504e commit 9bd1eab

File tree

3 files changed

+222
-15
lines changed

3 files changed

+222
-15
lines changed

aperturedb/Images.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def rotate(self, angle, resize=False):
6868

6969
class Images(object):
7070

71-
def __init__(self, db, batch_size=20):
71+
def __init__(self, db, batch_size=100):
7272

7373
self.db_connector = db
7474

@@ -84,7 +84,6 @@ def __init__(self, db, batch_size=20):
8484
self.search_result = None
8585

8686
self.batch_size = batch_size
87-
self.max_cached_images = 1000
8887
self.total_cached_images = 0
8988
self.display_limit = 20
9089

@@ -279,7 +278,7 @@ def get_np_image_by_index(self, index):
279278

280279
image = self.get_image_by_index(index)
281280
# Just decode the image from buffer
282-
nparr = np.fromstring(image, np.uint8)
281+
nparr = np.frombuffer(image, dtype=np.uint8)
283282
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
284283

285284
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -337,7 +336,6 @@ def search(self, constraints=None, operations=None, format=None, limit=None):
337336

338337
for ent in entities:
339338
self.images_ids.append(ent[self.img_id_prop])
340-
341339
except:
342340
print("Error with search")
343341

@@ -424,7 +422,7 @@ def display(self, show_bboxes=False, show_segmentation=False, limit=None):
424422
image = self.get_image_by_index(i)
425423

426424
# Just decode the image from buffer
427-
nparr = np.fromstring(image, np.uint8)
425+
nparr = np.frombuffer(image, dtype=np.uint8)
428426
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
429427

430428
if show_bboxes:

aperturedb/PyTorchDataset.py

Lines changed: 116 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,136 @@
11
import os
2+
import math
23

3-
from aperturedb import Image
4+
import numpy as np
5+
import cv2
6+
7+
from aperturedb import Images
48

59
import torch
610
from torch.utils import data
711
from torchvision import transforms
812

9-
class ApertureDBDatasetConstraints(data.Dataset):
    """Map-style PyTorch dataset over the images matching a set of constraints.

    The search is executed once, at construction time, through an
    ``Images.Images`` handler; items are then fetched from that handler by
    position.
    """

    def __init__(self, db, constraints):
        """Run the image search that backs this dataset.

        Args:
            db: Connector handed to ``Images.Images``.
            constraints: Constraints object forwarded to
                ``Images.Images.search``.
        """
        self.imgs_handler = Images.Images(db)
        self.imgs_handler.search(constraints=constraints)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        Raises:
            IndexError: If ``index`` is past the last search result.
        """
        # The sequence protocol (and therefore ``for item in dataset``)
        # expects IndexError at the end of the data.  A StopIteration raised
        # here would be turned into a RuntimeError as soon as the lookup
        # happens inside a generator (PEP 479).
        if index >= self.imgs_handler.total_results():
            raise IndexError("dataset index out of range")

        img = self.imgs_handler.get_np_image_by_index(index)

        # This is temporary until we define a good, generic way of
        # retrieving a label associated with the image.
        label = "none"
        return img, label

    def __len__(self):
        """Return the number of images found by the search."""
        return self.imgs_handler.total_results()
36+
37+
class ApertureDBDataset(data.Dataset):
    """Map-style PyTorch dataset backed by a native ApertureDB query.

    The query must contain a ``FindImage`` command.  Images are fetched from
    the server in batches of ``batch_size`` and only the most recently
    fetched batch is kept in memory, so sequential access costs one round
    trip per batch.
    """

    def __init__(self, db, query, label_prop=None):
        """Prepare the batched query and ask the server for the result count.

        Args:
            db: Connector exposing ``query(query)``, which returns a
                ``(response, blobs)`` pair, and ``get_last_response_str()``.
            query: List of single-command dicts; at least one of them must
                be a ``FindImage`` command.  The list is not modified.
            label_prop: Name of the entity property used as the label, or
                ``None`` to label every image ``"none"``.

        Raises:
            Exception: If the query has no ``FindImage`` command, or if the
                initial query fails (the original error is re-raised).
        """
        self.db = db
        self.find_image_idx = None
        self.total_elements = 0
        self.batch_size = 100
        self.batch_images = []
        self.batch_labels = []
        self.batch_start = 0
        self.batch_end = 0
        self.label_prop = label_prop

        # Work on a private copy of the command list so that adding the
        # "batch" block below does not leak into the caller's query.
        self.query = list(query)

        # When several FindImage commands are present the last one wins.
        for i, command in enumerate(self.query):
            if next(iter(command), None) == "FindImage":
                self.find_image_idx = i

        if self.find_image_idx is None:
            print("Query error. The query must contain one FindImage command")
            raise Exception('Query Error')

        # Copy only the dicts on the path that gets modified.
        command = dict(self.query[self.find_image_idx])
        find_image = dict(command["FindImage"])
        results = dict(find_image.get("results", {}))
        # An empty batch block is what this first query sends to obtain
        # "total_elements" from the response.
        results["batch"] = {}
        find_image["results"] = results
        command["FindImage"] = find_image
        self.query[self.find_image_idx] = command

        try:
            response, _ = self.db.query(self.query)
            batch = response[self.find_image_idx]["FindImage"]["batch"]
            self.total_elements = batch["total_elements"]
        except Exception:
            print("Query error:")
            print(self.query)
            print(self.db.get_last_response_str())
            raise

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``.

        The image blob is decoded with OpenCV and converted from BGR to RGB.
        Negative indices count from the end, as for regular sequences.

        Raises:
            IndexError: If ``index`` is outside the dataset.
        """
        if index < 0:
            index = index + self.total_elements

        # IndexError (not StopIteration) is what the sequence protocol
        # expects at the end of the data.
        if not 0 <= index < self.total_elements:
            raise IndexError("dataset index out of range")

        if not self.is_in_range(index):
            self.get_batch(index)

        offset = index - self.batch_start
        img = self.batch_images[offset]
        label = self.batch_labels[offset]

        nparr = np.frombuffer(img, dtype=np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img, label

    def __len__(self):
        """Return the number of images reported by the server."""
        return self.total_elements

    def is_in_range(self, index):
        """Return True when ``index`` falls inside the cached batch."""
        return self.batch_start <= index < self.batch_end

    def get_batch(self, index):
        """Fetch and cache the batch that contains ``index``.

        The cache is replaced only after both blobs and labels have been
        obtained, so a failure leaves the previous batch intact.

        Raises:
            Exception: If the server returns no blobs, or if the query or
                the label extraction fails (the original error is re-raised).
        """
        # Integer division keeps the batch id exact for large indices;
        # int() keeps it a plain Python int for non-builtin integer types.
        batch_idx = int(index) // self.batch_size

        find_image = self.query[self.find_image_idx]["FindImage"]
        qbatch = find_image["results"]["batch"]
        qbatch["batch_size"] = self.batch_size
        qbatch["batch_id"] = batch_idx

        try:
            response, blobs = self.db.query(self.query)
            if len(blobs) == 0:
                print("index:", index)
                raise Exception("No results returned from ApertureDB")

            if self.label_prop:
                entities = response[self.find_image_idx]["FindImage"]["entities"]
                labels = [entity[self.label_prop] for entity in entities]
            else:
                labels = ["none"] * len(blobs)
        except Exception:
            print("Query error:")
            print(self.db.get_last_response_str())
            raise

        self.batch_images = blobs
        self.batch_labels = labels
        self.batch_start = self.batch_size * batch_idx
        self.batch_end = self.batch_start + len(blobs)

test/test_torch_connector.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import argparse
2+
import time
3+
import unittest
4+
5+
import dbinfo
6+
7+
from aperturedb import Connector, Status
8+
from aperturedb import Images
9+
from aperturedb import PyTorchDataset
10+
11+
class TestTorch(unittest.TestCase):
    """Base test case that waits until the ApertureDB server is reachable."""

    # Number of failed connection attempts tolerated before giving up.
    MAX_CONNECT_ATTEMPTS = 10

    def __init__(self, *args, **kwargs):
        """Store the server address and block until a connection succeeds.

        Raises:
            SystemExit: With status 1 when the server cannot be reached
                after ``MAX_CONNECT_ATTEMPTS`` attempts.
        """
        super().__init__(*args, **kwargs)

        # ApertureDB Server Info
        self.db_host = dbinfo.DB_HOST
        self.db_port = dbinfo.DB_PORT

        attempts = 0
        while True:
            try:
                # Only reachability matters here; each test opens its own
                # connection.
                Connector.Connector(self.db_host, self.db_port)
            except Exception:
                # Exception, not a bare except, so Ctrl-C still interrupts
                # the retry loop.
                print("Attempt", attempts,
                      "to connect to ApertureDB failed, retrying...")
                attempts += 1

                if attempts >= self.MAX_CONNECT_ATTEMPTS:
                    print("Failed to connect to ApertureDB after",
                          attempts, "attempts")
                    # Non-zero status so the failure is visible to whatever
                    # launched the tests.
                    raise SystemExit(1)

                time.sleep(1)  # sleeps 1 second
            else:
                if attempts > 0:
                    print("Connection to ApertureDB successful.")
                break
37+
38+
class TestTorchDatasets(TestTorch):

    '''
    These tests need to be run after the Loaders, because they use
    data inserted by the loaders.
    '''

    def _check_dataset(self, db, dataset):
        """Check the dataset size, iterate over it and print the throughput.

        Args:
            db: Connector used to ask the server for its image count.
            dataset: Dataset under test; it must yield ``(image, label)``
                pairs.
        """
        dbstatus = Status.Status(db)
        self.assertEqual(len(dataset), dbstatus.count_images())

        start = time.time()

        # Iterate over dataset.
        for img, _label in dataset:
            # len() is never negative, so emptiness has to be tested with
            # "== 0"; a missing image is caught as well.
            if img is None or len(img) == 0:
                self.fail("Empty image?")

        elapsed = time.time() - start

        print("\n")
        # Guard against a zero interval (e.g. an empty dataset on a coarse
        # clock).
        if elapsed > 0:
            print("Throughput (imgs/s):", len(dataset) / elapsed)

    def test_omConstraints(self):
        """Iterate a dataset built from an ``Images.Constraints`` object."""
        db = Connector.Connector(self.db_host, self.db_port)

        const = Images.Constraints()
        const.greaterequal("age", 0)
        dataset = PyTorchDataset.ApertureDBDatasetConstraints(
            db, constraints=const)

        self._check_dataset(db, dataset)

    def test_nativeContraints(self):
        """Iterate a dataset built from a native ``FindImage`` query."""
        db = Connector.Connector(self.db_host, self.db_port)

        query = [{
            "FindImage": {
                "constraints": {
                    "age": [">=", 0]
                },
                "operations": [
                    {
                        "type": "resize",
                        "width": 224,
                        "height": 224
                    }
                ],
                "results": {
                    "list": ["license"]
                }
            }
        }]

        dataset = PyTorchDataset.ApertureDBDataset(
            db, query, label_prop="license")

        self._check_dataset(db, dataset)

0 commit comments

Comments
 (0)