|
1 | 1 | import os |
| 2 | +import math |
2 | 3 |
|
3 | | -from aperturedb import Image |
| 4 | +import numpy as np |
| 5 | +import cv2 |
| 6 | + |
| 7 | +from aperturedb import Images |
4 | 8 |
|
5 | 9 | import torch |
6 | 10 | from torch.utils import data |
7 | 11 | from torchvision import transforms |
8 | 12 |
|
class ApertureDBDatasetConstraints(data.Dataset):
    """Dataset over the images that match a set of search constraints.

    The search is executed once, at construction time, through
    ``Images.Images``; samples are afterwards read from that handler by
    index.
    """

    def __init__(self, db, constraints):
        """Create the image handler and run the search.

        Args:
            db: Object handed to ``Images.Images``.
            constraints: Forwarded to the handler's ``search`` call.
        """
        self.imgs_handler = Images.Images(db)
        self.imgs_handler.search(constraints=constraints)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair for ``index``.

        The image is whatever ``get_np_image_by_index`` returns; the label
        is currently always the string ``"none"``.

        Raises:
            StopIteration: If ``index`` is greater than or equal to the
                number of search results.
        """
        if index >= self.imgs_handler.total_results():
            raise StopIteration

        img = self.imgs_handler.get_np_image_by_index(index)

        # This is temporary until we define a good, generic way of
        # retrieving a label associated with the image.
        label = "none"
        return img, label

    def __len__(self):
        """Return the number of results reported by the image handler."""
        return self.imgs_handler.total_results()
| 36 | + |
class ApertureDBDataset(data.Dataset):
    """Dataset that pages through the results of a ``FindImage`` query.

    The query is sent once at construction time with an empty ``batch``
    request to learn the total number of elements. Samples are then fetched
    lazily, ``batch_size`` at a time, and the most recently fetched batch is
    kept in memory.
    """

    def __init__(self, db, query, label_prop=None):
        """Prepare the query and ask the database for the result count.

        Args:
            db: Object exposing ``query(query)`` returning a
                ``(response, blobs)`` pair and ``get_last_response_str()``.
            query: List of single-key command dicts. It must contain a
                ``"FindImage"`` command; if several are present the last
                one is used.
            label_prop: Optional key read from each returned entity to
                produce the sample label. When falsy, every label is the
                string ``"none"``.

        Raises:
            Exception: ``'Query Error'`` if the query has no ``FindImage``
                command. Any error raised while running the initial query
                is re-raised after diagnostics are printed.
        """
        # Local import: only the constructor needs it.
        import copy

        self.db = db
        # Work on a private copy so the batch settings injected below (and
        # rewritten on every get_batch call) never leak into the caller's
        # query object.
        self.query = copy.deepcopy(query)
        self.find_image_idx = None
        self.total_elements = 0
        self.batch_size = 100
        self.batch_images = []
        self.batch_labels = []
        self.batch_start = 0
        self.batch_end = 0
        self.label_prop = label_prop

        # The command name is the first key of each command dict. The loop
        # deliberately does not break, so the last FindImage wins.
        for position, command in enumerate(self.query):
            if next(iter(command), None) == "FindImage":
                self.find_image_idx = position

        if self.find_image_idx is None:
            print("Query error. The query must contain one FindImage command")
            raise Exception('Query Error')

        find_image = self.query[self.find_image_idx]["FindImage"]
        # An empty "batch" request is sent first; the response is read
        # below for "total_elements".
        find_image.setdefault("results", {})["batch"] = {}

        try:
            response, _ = self.db.query(self.query)
            batch = response[self.find_image_idx]["FindImage"]["batch"]
            self.total_elements = batch["total_elements"]
        except Exception:
            print("Query error:")
            print(self.query)
            print(self.db.get_last_response_str())
            raise

    def __getitem__(self, index):
        """Return the decoded ``(image, label)`` pair for ``index``.

        The stored blob is decoded with ``cv2.imdecode`` as a color image
        and converted from BGR to RGB.

        Raises:
            StopIteration: If ``index`` is greater than or equal to the
                total number of elements (kept, rather than ``IndexError``,
                for compatibility with existing callers).
        """
        if index >= self.total_elements:
            raise StopIteration

        if not self.is_in_range(index):
            self.get_batch(index)

        # Offset inside the cached batch; relative to batch_start so it
        # stays correct even if batch_size is changed while a batch is
        # cached.
        offset = index - self.batch_start
        blob = self.batch_images[offset]
        label = self.batch_labels[offset]

        encoded = np.frombuffer(blob, dtype=np.uint8)
        img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img, label

    def __len__(self):
        """Return the total number of elements reported by the database."""
        return self.total_elements

    def is_in_range(self, index):
        """Return True if ``index`` falls inside the cached batch."""
        return self.batch_start <= index < self.batch_end

    def get_batch(self, index):
        """Fetch and cache the batch that contains ``index``.

        Updates ``batch_images``, ``batch_labels``, ``batch_start`` and
        ``batch_end``. The cached state is only replaced once the whole
        batch (blobs and labels) has been obtained, so a failed fetch
        leaves the previous batch intact.

        Raises:
            Exception: ``"No results returned from ApertureDB"`` if the
                query returns no blobs. Any other error raised while
                querying or extracting labels is re-raised after
                diagnostics are printed.
        """
        batch_idx = index // self.batch_size

        find_image = self.query[self.find_image_idx]["FindImage"]
        qbatch = find_image["results"]["batch"]
        qbatch["batch_size"] = self.batch_size
        qbatch["batch_id"] = batch_idx

        try:
            response, blobs = self.db.query(self.query)
            if len(blobs) == 0:
                print("index:", index)
                raise Exception("No results returned from ApertureDB")

            if self.label_prop:
                entities = response[self.find_image_idx]["FindImage"]["entities"]
                labels = [entity[self.label_prop] for entity in entities]
            else:
                labels = ["none"] * len(blobs)
        except Exception:
            print("Query error:")
            print(self.db.get_last_response_str())
            raise

        # Commit the new batch only after everything above succeeded.
        self.batch_images = blobs
        self.batch_labels = labels
        self.batch_start = self.batch_size * batch_idx
        self.batch_end = self.batch_start + len(blobs)
0 commit comments