-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsnippet.py
More file actions
165 lines (141 loc) · 5.3 KB
/
snippet.py
File metadata and controls
165 lines (141 loc) · 5.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import argparse
import imagehash
import json
import logging
import math
import numpy as np
import pandas as pd
import re
import time
from google.cloud import datastore, storage, vision
from PIL import Image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
parser.add_argument('project', help='The GCP project')
parser.add_argument('dataset', help='The dataset CSV path')
# BUG FIX: the original used `type=bool`, which is broken for CLI values:
# bool('False') is True, so any non-empty argument (including "False")
# enabled the export. Parse the string explicitly instead; omitting the
# flag still means "disabled".
parser.add_argument('--export_json',
                    type=lambda s: s.lower() in ('true', '1', 'yes'),
                    help='Export label result to JSON (true/false)')
args = parser.parse_args()

# Minimum Vision API label confidence required to keep a label.
SCORE_THRESHOLD = 0.7
# Local folder holding the images referenced by the dataset CSV.
IMAGE_FOLDER = 'images'
# Number of images sent per Vision API batch request.
MAX_BATCH_SIZE = 16
BUCKET_NAME = 'vision-api-example'
ENTITY_TYPE = 'ImageLabel'
NAMESPACE = 'vision-api-example'

# Clients are created once at import time and shared by all helpers below.
gcs_client = storage.Client(project=args.project)
bucket = gcs_client.get_bucket(BUCKET_NAME)
datastore_client = datastore.Client(project=args.project)
vision_client = vision.ImageAnnotatorClient()
def load_dataset(file_path):
    """Load the dataset CSV and drop images already labelled in Datastore.

    Args:
        file_path: Path to a CSV with a `file` column of image file names.

    Returns:
        DataFrame with added `file_path`, `file_id` and `image_hash`
        columns, containing only rows whose image hash has no Datastore
        entity yet.
    """
    _df = pd.read_csv(file_path)
    _df['file_path'] = IMAGE_FOLDER + '/' + _df.file
    _df['file_id'] = _df.apply(get_file_id, axis=1)
    _df['image_hash'] = _df.apply(get_image_hash, axis=1)
    # Look up existing entities keyed by image hash so reruns skip images
    # that were already processed.
    keys = [datastore_client.key(ENTITY_TYPE, h, namespace=NAMESPACE)
            for h in _df['image_hash'].values]
    if keys:
        image_labels = datastore_client.get_multi(keys)
        if image_labels:
            existing_hashes = [ent.key.id_or_name for ent in image_labels]
            _df = _df.loc[~_df['image_hash'].isin(existing_hashes)]
    # BUG FIX: the original called `_df.set_index('file_id')` and discarded
    # the result (set_index is not in-place), making it a no-op. Downstream
    # code (store_data) reads `row.file_id` as a column, so assigning the
    # result back would actually break it; the dead statement is removed.
    return _df
def get_file_id(row):
    """Extract the bare file name (no folder, no extension) from `row.file_path`.

    e.g. 'images/cat-01.jpg' -> 'cat-01'. Returns None when the path does
    not match the expected '<folder>/<name>.<ext>' shape.
    """
    # BUG FIX: the original pattern used an unescaped '.', which matched any
    # character before the extension; an extensionless path like 'images/foo'
    # produced the truncated id 'f'. Use a raw string and escape the dot so
    # only a real extension separator matches.
    pattern = re.compile(r'.+/([A-Za-z0-9_-]+)\.[A-Za-z]+')
    match = pattern.match(row.file_path)
    return match.group(1) if match else None
def get_image_hash(row):
    """Return the perceptual average hash (as a string) of the image at `row.file_path`.

    The hash is used as the Datastore entity key, so identical images are
    never labelled twice even if the file name changes.
    """
    # Use a context manager so the underlying file handle is released
    # promptly instead of waiting for the Image object to be garbage
    # collected (the original leaked the handle).
    with Image.open(row.file_path) as img:
        return str(imagehash.average_hash(img))
def filter_labels(labels):
    """Return {description: score} for labels scoring at or above SCORE_THRESHOLD.

    Scores are rounded to two decimals via string formatting. An empty or
    None input yields an empty dict.
    """
    if not labels:
        return {}
    return {
        label.description: float('{0:.2f}'.format(label.score))
        for label in labels
        if label.score >= SCORE_THRESHOLD
    }
# Get image labels from Vision API using batch request
def get_label(_df):
    """Run Vision API label detection for every image in `_df`.

    Args:
        _df: DataFrame with a `file_path` column of local image paths.

    Returns:
        DataFrame with columns `file_path` and `image_labels`, where
        `image_labels` is a dict of {description: score} at or above the
        score threshold. Unreadable files are logged and skipped.
    """
    files = _df.file_path.values
    file_index = {}  # file path -> position of its request in the batch
    requests = []
    for f in files:
        try:
            with open(f, 'rb') as image_file:
                requests.append({
                    'image': {'content': image_file.read()},
                    'features': [{'type': vision.enums.Feature.Type.LABEL_DETECTION}]
                })
            file_index[f] = len(requests) - 1
        except Exception as e:
            # Best-effort: skip files that cannot be read and keep going.
            logger.error(e)
    rows = []
    if requests:
        logger.info('Detect labels for image batch %s' % files)
        batch_response = vision_client.batch_annotate_images(requests)
        if batch_response.responses:
            for path, i in file_index.items():
                labels = filter_labels(batch_response.responses[i].label_annotations)
                rows.append({'file_path': path, 'image_labels': labels})
    # BUG FIX: the original grew the frame row-by-row with DataFrame.append,
    # which was deprecated and removed in pandas 2.0. Collect the rows first
    # and build the frame once.
    return pd.DataFrame(rows, columns=['file_path', 'image_labels'])
# Upload image to GCS
def upload_to_gcs(_file_id, _file_path):
    """Upload a local file to the GCS bucket under the given object name.

    Args:
        _file_id: Object name to use inside the bucket.
        _file_path: Local path of the file to upload.

    Returns:
        The `gs://bucket/object` URI of the uploaded file.

    Raises:
        Exception: re-raises any upload failure after logging it.
    """
    gcs_filename = 'gs://%s/%s' % (BUCKET_NAME, _file_id)
    try:
        blob = bucket.blob(_file_id)
        blob.upload_from_filename(filename=_file_path)
        return gcs_filename
    except Exception:
        # Use the module logger (the original called the root logger via
        # `logging.error`, bypassing this module's logger configuration),
        # and re-raise bare to preserve the original traceback.
        logger.error('Error uploading file %s' % gcs_filename)
        raise
# Store images to GCS, export result to JSON. Return entity to be persisted in batch
def store_data(row):
    """Upload one image to GCS and build its Datastore entity.

    Optionally writes the labels to `output/<file_id>.json` when the
    --export_json flag is set.

    Args:
        row: DataFrame row with `file_id`, `file_path`, `image_hash` and
            `image_labels` attributes.

    Returns:
        The populated datastore.Entity, or None when any step failed.
        Failures are logged rather than raised so one bad image does not
        abort the whole batch.
    """
    try:
        gcs_file = upload_to_gcs(row.file_id, row.file_path)
        entity_key = datastore_client.key(ENTITY_TYPE, row.image_hash, namespace=NAMESPACE)
        image_label = datastore.Entity(key=entity_key, exclude_from_indexes=['gcs_file', 'labels'])
        image_label.update({
            'gcs_file': gcs_file,
            'labels': row.image_labels
        })
        if args.export_json:
            with open('output/%s.json' % row.file_id, 'w') as fp:
                label_data = {
                    'imageHash': row.image_hash,
                    'labels': row.image_labels
                }
                json.dump(label_data, fp, indent=2, separators=(',', ': '))
        return image_label
    except Exception:
        # Log the full traceback via the module logger (the original used
        # the root logger and dropped the exception details entirely), then
        # return None explicitly so callers can filter out failed rows.
        logger.exception('Error storing file data: %s' % row.file_id)
        return None
# Store images to GCS, export result to JSON and persist image data to datastore
def store_image_label(_df):
    """Persist every row of `_df` to Datastore in batches of 100 entities.

    Rows whose store_data call failed are skipped so a single bad image
    cannot make an entire put_multi batch fail.
    """
    # BUG FIX: store_data returns None on failure; the original passed those
    # Nones straight into put_multi, which would raise. Filter them out.
    entities = [e for e in (store_data(row) for _, row in _df.iterrows())
                if e is not None]
    if entities:
        # Split into chunks of 100 to stay well under Datastore's per-call
        # write limit.
        entity_partitions = math.ceil(len(entities) / 100)
        entity_batches = np.array_split(entities, entity_partitions)
        for batch in entity_batches:
            datastore_client.put_multi(batch.tolist())
    else:
        logger.info('No entities to store.')
def tag_images(_df):
    """Label one batch of images via the Vision API and persist the results."""
    label_df = get_label(_df)
    # Join the detected labels back onto the batch by shared file path.
    merged = pd.merge(_df, label_df, on='file_path')
    store_image_label(merged)
if __name__ == "__main__":
starttime = time.time()
df = load_dataset(args.dataset)
if len(df) > 0:
partitions = math.ceil(len(df) / MAX_BATCH_SIZE)
df_batches = np.array_split(df, partitions)
for d in df_batches:
tag_images(d)
endtime = time.time()
elapsed = endtime - starttime
print('Elapsed %ss' % int(round(elapsed)))