11import math
22import time
3+ import requests
4+ import os
5+ import boto3
36from threading import Thread
47
58import numpy as np
912from aperturedb import ParallelLoader
1013from aperturedb import CSVParser
1114
12- HEADER_PATH = "filename"
13- PROPERTIES = "properties"
14- CONSTRAINTS = "constraints"
15- IMG_FORMAT = "format"
15+ HEADER_PATH = "filename"
16+ HEADER_URL = "url"
17+ HEADER_S3_URL = "s3_url"
18+ PROPERTIES = "properties"
19+ CONSTRAINTS = "constraints"
20+ IMG_FORMAT = "format"
1621
1722class ImageGeneratorCSV (CSVParser .CSVParser ):
1823
@@ -21,6 +26,11 @@ class ImageGeneratorCSV(CSVParser.CSVParser):
2126 Expects a csv file with the following columns (format optional):
2227
2328 filename,PROP_NAME_1, ... PROP_NAME_N,constraint_PROP1,format
29+ OR
30+ url,PROP_NAME_1, ... PROP_NAME_N,constraint_PROP1,format
31+ OR
32+ s3_url,PROP_NAME_1, ... PROP_NAME_N,constraint_PROP1,format
33+ ...
2434
2535 Example csv file:
2636 filename,id,label,constraint_id,format
@@ -29,7 +39,7 @@ class ImageGeneratorCSV(CSVParser.CSVParser):
2939 ...
3040 '''
3141
32- def __init__ (self , filename , check_image = True ):
42+ def __init__ (self , filename , check_image = True , n_download_retries = 3 ):
3343
3444 super ().__init__ (filename )
3545
@@ -40,13 +50,31 @@ def __init__(self, filename, check_image=True):
4050 self .props_keys = [x for x in self .props_keys if x != IMG_FORMAT ]
4151 self .constraints_keys = [x for x in self .header [1 :] if x .startswith (CSVParser .CONTRAINTS_PREFIX ) ]
4252
53+ self .source_type = self .header [0 ]
54+ if self .source_type not in [ HEADER_PATH , HEADER_URL , HEADER_S3_URL ]:
55+ print ("Source not recognized: " + self .source_type )
56+ raise Exception ("Error loading image: " + filename )
57+
58+ self .n_download_retries = n_download_retries
59+
4360 # TODO: we can add support for slicing here.
4461 def __getitem__ (self , idx ):
4562
46- filename = self .df .loc [idx , HEADER_PATH ]
4763 data = {}
4864
49- img_ok , img = self .load_image (filename )
65+ img_ok = True
66+ img = None
67+
68+ if self .source_type == HEADER_PATH :
69+ image_path = self .df .loc [idx , HEADER_PATH ]
70+ img_ok , img = self .load_image (image_path )
71+ elif self .source_type == HEADER_URL :
72+ image_path = self .df .loc [idx , HEADER_URL ]
73+ img_ok , img = self .load_url (image_path )
74+ elif self .source_type == HEADER_S3_URL :
75+ image_path = self .df .loc [idx , HEADER_S3_URL ]
76+ img_ok , img = self .load_s3_url (image_path )
77+
5078 if not img_ok :
5179 print ("Error loading image: " + filename )
5280 raise Exception ("Error loading image: " + filename )
@@ -67,12 +95,12 @@ def __getitem__(self, idx):
6795 return data
6896
6997 def load_image (self , filename ):
70-
7198 if self .check_image :
7299 try :
73100 a = cv2 .imread (filename )
74101 if a .size <= 0 :
75102 print ("IMAGE SIZE ERROR:" , filename )
103+ return false , None
76104 except :
77105 print ("IMAGE ERROR:" , filename )
78106
@@ -83,14 +111,73 @@ def load_image(self, filename):
83111 return True , buff
84112 except :
85113 print ("IMAGE ERROR:" , filename )
114+ return False , None
115+
def check_image_buffer(self, img):
    '''
    Check that an encoded image buffer can be decoded by OpenCV.

    img: encoded image bytes wrapped in a numpy uint8 array (the callers
         in this class build it with np.frombuffer(..., dtype='uint8')).

    Returns True when cv2.imdecode produces an image, False otherwise.
    '''
    try:
        decoded_img = cv2.imdecode(img, cv2.IMREAD_COLOR)
    except Exception:
        # imdecode raises for buffers it cannot even inspect
        # (e.g. an empty array).
        return False

    # For data that is not a valid image imdecode does not raise, it
    # returns None, so the result itself has to be checked.
    return decoded_img is not None
126+
def load_url(self, url):
    '''
    Download an image from a URL, retrying failed attempts.

    An attempt fails when the request raises (connection error, timeout,
    ...) or when the response status is not OK. After a failed attempt
    the download is retried up to self.n_download_retries times, waiting
    2 seconds between attempts.

    url: location of the image.

    Returns (True, raw_bytes) on success, and (False, None) when all
    attempts failed or when self.check_image is set and the downloaded
    bytes do not decode as an image.
    '''
    retries = 0
    while True:
        # NOTE(review): no timeout is passed to requests.get, so a stalled
        # server can block this call indefinitely - consider adding one.
        try:
            imgdata = requests.get(url)
            downloaded = imgdata.ok
        except requests.exceptions.RequestException as e:
            # Network-level failures are as transient as a bad status
            # code: send them through the retry path instead of
            # propagating on the first attempt.
            print("WARNING: Request failed:", url, e)
            downloaded = False

        if downloaded:
            content = imgdata.content
            if self.check_image:
                imgbuffer = np.frombuffer(content, dtype='uint8')
                if not self.check_image_buffer(imgbuffer):
                    print("IMAGE ERROR: ", url)
                    return False, None

            return True, content

        if retries >= self.n_download_retries:
            break
        print("WARNING: Retrying object:", url)
        retries += 1
        time.sleep(2)

    return False, None
146+
def load_s3_url(self, s3_url):
    '''
    Download an image from S3, retrying failed attempts.

    s3_url: object location in the form s3://<bucket>/<key>.

    A failed get/read is retried up to self.n_download_retries times,
    waiting 2 seconds between attempts.

    Returns (True, raw_bytes) on success, and (False, None) when the URL
    cannot be parsed, when all attempts failed, or when self.check_image
    is set and the downloaded bytes do not decode as an image.
    '''
    # Parsing does not depend on the attempt, so do it once up front; a
    # malformed URL can never succeed and is rejected without retrying.
    try:
        bucket_name = s3_url.split("/")[2]
    except (AttributeError, IndexError):
        print("S3 ERROR:", s3_url)
        return False, None
    object_name = s3_url.split("s3://" + bucket_name + "/")[-1]

    # The connections by boto3 cause ResourceWarning. Known
    # issue: https://github.com/boto/boto3/issues/454
    s3 = boto3.client('s3')

    retries = 0
    while True:
        try:
            s3_response_object = s3.get_object(
                Bucket=bucket_name, Key=object_name)
            img = s3_response_object['Body'].read()
        except Exception:
            # Exception (not a bare except) so that KeyboardInterrupt and
            # SystemExit still stop the loader instead of being retried.
            if retries >= self.n_download_retries:
                break
            print("WARNING: Retrying object:", s3_url)
            retries += 1
            time.sleep(2)
            continue

        if self.check_image:
            imgbuffer = np.frombuffer(img, dtype='uint8')
            if not self.check_image_buffer(imgbuffer):
                print("IMAGE ERROR: ", s3_url)
                return False, None

        return True, img

    print("S3 ERROR:", s3_url)
    return False, None
88175
def validate(self):
    '''
    Read the CSV header and check that its first column names the image
    source.

    Stores the column names of self.df in self.header.

    Raises Exception when the CSV has no columns or when the first column
    is not one of the supported source headers (HEADER_PATH, HEADER_URL,
    HEADER_S3_URL).
    '''
    self.header = list(self.df.columns.values)

    source_headers = (HEADER_PATH, HEADER_URL, HEADER_S3_URL)

    # The emptiness check keeps a CSV without columns from surfacing as an
    # unrelated IndexError.
    if not self.header or self.header[0] not in source_headers:
        raise Exception(
            "Error with CSV file field: first field must be one of: "
            + ", ".join(source_headers))
95182
96183class ImageLoader (ParallelLoader .ParallelLoader ):
0 commit comments