Skip to content

Commit a6a1ae0

Browse files
committed
feat: add color selector
1 parent 288ef2c commit a6a1ae0

8 files changed

Lines changed: 250 additions & 21 deletions

File tree

character.py

Lines changed: 179 additions & 0 deletions
Large diffs are not rendered by default.

config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(self, conf_path: str, graph_path: str = None, model_path: str = Non
2424
self.logger_tag = self.sys_cf['System'].get('LoggerTag')
2525
self.logger_tag = self.logger_tag if self.logger_tag else "coriander"
2626
self.logger = logging.getLogger(self.logger_tag)
27+
self.static_path = self.sys_cf['System'].get('StaticPath')
28+
self.static_path = self.static_path if self.static_path else 'static'
2729
self.use_default_authorization = False
2830
self.authorization = None
2931
self.init_logger()

demo.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,16 @@
1515
DEFAULT_HOST = "localhost"
1616

1717

18-
def _image(_path, model_type=None, model_site=None):
18+
def _image(_path, model_type=None, model_site=None, need_color=None):
1919
with open(_path, "rb") as f:
2020
img_bytes = f.read()
2121

2222
b64 = base64.b64encode(img_bytes).decode()
2323
return {
2424
'image': b64,
2525
'model_type': model_type,
26-
'model_site': model_site
26+
'model_site': model_site,
27+
'need_color': need_color,
2728
}
2829

2930

@@ -225,13 +226,14 @@ def press_testing(self, image_list: dict, model_type=None, model_site=None):
225226
_path.split('_')[0].lower(): _image(
226227
os.path.join(path, _path),
227228
model_type=None,
228-
model_site=None
229+
model_site=None,
230+
need_color=None,
229231
)
230232
for i, _path in enumerate(path_list)
231233
if i < 1000
232234
}
233-
print(batch)
234-
# NoAuth(DEFAULT_HOST, ServerType.TORNADO).press_testing(batch)
235+
# print(batch)
236+
# NoAuth(DEFAULT_HOST, ServerType.TORNADO).local_iter(batch)
235237
# NoAuth(DEFAULT_HOST, ServerType.FLASK).local_iter(batch)
236238
# NoAuth(DEFAULT_HOST, ServerType.SANIC).local_iter(batch)
237239
GoogleRPC(DEFAULT_HOST).local_iter(batch, model_site=None, model_type=None)

grpc.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ message PredictRequest {
1010
string model_name = 3;
1111
string model_type = 4;
1212
string model_site = 5;
13+
string need_color = 6;
1314
}
1415

1516
message PredictResult {

grpc_pb2.py

Lines changed: 14 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

grpc_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def predict(self, request, context):
4141
if not interface:
4242
logger.info('Service is not ready!')
4343
return {"result": "", "success": False, "code": 999}
44-
image_batch, status = ImageUtils.get_image_batch(interface.model_conf, bytes_batch)
44+
image_batch, status = ImageUtils.get_image_batch(interface.model_conf, bytes_batch, color=request.need_color)
4545

4646
if not image_batch:
4747
return grpc_pb2.PredictResult(result="", success=status['success'], code=status['code'])

tornado_server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def post(self):
7979
model_site = ParamUtils.filter(data.get('model_site'))
8080
model_name = ParamUtils.filter(data.get('model_name'))
8181
split_char = ParamUtils.filter(data.get('split_char'))
82+
need_color = ParamUtils.filter(data.get('need_color'))
8283

8384
if not bytes_batch:
8485
logger.error('Type[{}] - Site[{}] - Response[{}] - {} ms'.format(
@@ -101,7 +102,7 @@ def post(self):
101102

102103
split_char = split_char if 'split_char' in data else interface.model_conf.split_char
103104

104-
image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch)
105+
image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch, color=need_color)
105106

106107
if not image_batch:
107108
logger.error('Type[{}] - Site[{}] - Response[{}] - {} ms'.format(
@@ -138,6 +139,7 @@ def post(self):
138139
model_site = ParamUtils.filter(data.get('model_site'))
139140
model_name = ParamUtils.filter(data.get('model_name'))
140141
split_char = ParamUtils.filter(data.get('split_char'))
142+
need_color = ParamUtils.filter(data.get('need_color'))
141143

142144
bytes_batch, response = ImageUtils.get_bytes_batch(data['image'])
143145

@@ -162,7 +164,7 @@ def post(self):
162164

163165
split_char = split_char if 'split_char' in data else interface.model_conf.split_char
164166

165-
image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch)
167+
image_batch, response = ImageUtils.get_image_batch(interface.model_conf, bytes_batch, color=need_color)
166168

167169
if not image_batch:
168170
logger.error('[{}] - Size[{}] - Type[{}] - Site[{}] - Response[{}] - {} ms'.format(

utils.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,23 +70,59 @@ def get_bytes_batch(base64_img):
7070
return bytes_batch, response.SUCCESS
7171

7272
@staticmethod
73-
def get_image_batch(model: ModelConfig, bytes_batch):
73+
def get_image_batch(model: ModelConfig, bytes_batch, color=None):
7474
# Note that there are two return objects here.
7575
# 1.image_batch, 2.response
7676

7777
response = Response()
7878

79-
def load_image(image_bytes):
80-
data_stream = io.BytesIO(image_bytes)
81-
pil_image = PIL_Image.open(data_stream).convert('RGB')
82-
image = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2GRAY)
79+
hsv_map = {
80+
"blue": {
81+
"lower_hsv": np.array([100, 128, 46]),
82+
"high_hsv": np.array([124, 255, 255])
83+
},
84+
"red": {
85+
"lower_hsv": np.array([0, 128, 46]),
86+
"high_hsv": np.array([5, 255, 255])
87+
},
88+
"yellow": {
89+
"lower_hsv": np.array([15, 128, 46]),
90+
"high_hsv": np.array([34, 255, 255])
91+
},
92+
"green": {
93+
"lower_hsv": np.array([35, 128, 46]),
94+
"high_hsv": np.array([77, 255, 255])
95+
},
96+
"black": {
97+
"lower_hsv": np.array([0, 0, 0]),
98+
"high_hsv": np.array([180, 255, 46])
99+
}
100+
}
101+
102+
def separate_color(pil_image, color):
103+
hsv = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_BGR2HSV)
104+
lower_hsv = hsv_map[color]['lower_hsv']
105+
high_hsv = hsv_map[color]['high_hsv']
106+
mask = cv2.inRange(hsv, lowerb=lower_hsv, upperb=high_hsv)
107+
return mask
108+
109+
def load_image(image_bytes, color=None):
110+
111+
if color and color in ['red', 'blue', 'black', 'green', 'yellow']:
112+
image = np.asarray(bytearray(image_bytes), dtype="uint8")
113+
image = cv2.imdecode(image, -1)
114+
image = separate_color(image, color)
115+
else:
116+
data_stream = io.BytesIO(image_bytes)
117+
pil_image = PIL_Image.open(data_stream).convert('RGB')
118+
image = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2GRAY)
83119
image = preprocessing(image, model.binaryzation, model.smooth, model.blur).astype(np.float32)
84120
image = cv2.resize(image, (model.resize[0], model.resize[1]))
85121
image = image.swapaxes(0, 1)
86122
return image[:, :, np.newaxis] / 255.
87123

88124
try:
89-
image_batch = [load_image(i) for i in bytes_batch]
125+
image_batch = [load_image(i, color=color) for i in bytes_batch]
90126
return image_batch, response.SUCCESS
91127
except OSError:
92128
return None, response.IMAGE_DAMAGE

0 commit comments

Comments
 (0)