|
23 | 23 | import logging
|
24 | 24 |
|
25 | 25 | import numpy as np
|
26 |
| -import tensorflow as tf |
| 26 | +from tensorflow.lite.python.interpreter import Interpreter |
| 27 | +import cv2 |
27 | 28 |
|
28 | 29 | logger = logging.getLogger(__name__)
|
29 | 30 |
|
class CNNClassifier(object):
    """Image classifier backed by a TensorFlow Lite interpreter.

    The model is loaded once in the constructor; every call to
    ``classify_image`` resizes the frame to the model's input size, runs
    one inference and returns the labels whose score exceeds
    ``SCORE_THRESHOLD``.
    """

    # Minimum (de-quantized) score a label needs in order to be reported.
    SCORE_THRESHOLD = 0.5

    def __init__(self, model_file, label_file, input_layer="input",
                 output_layer="final_result", input_height=128,
                 input_width=128, input_mean=127.5, input_std=127.5):
        """Load the TFLite model and its label list.

        Args:
            model_file: path of the ``.tflite`` model, handed to
                ``Interpreter(model_path=...)``.
            label_file: text file with one label per line.
            input_layer, output_layer, input_height, input_width: unused;
                kept so existing callers keep working. The input size is
                read from the model's own input tensor instead.
            input_mean, input_std: normalisation constants applied to the
                pixels when the model takes float32 input.
        """
        logger.info(model_file)
        self._interpreter = Interpreter(model_path=model_file)
        self._interpreter.set_num_threads(4)
        self._interpreter.allocate_tensors()
        self._labels = self.load_labels(label_file)

        self._input_details = self._interpreter.get_input_details()
        self._output_details = self._interpreter.get_output_details()
        # NOTE(review): indices 1 and 2 are taken as height and width, i.e.
        # an NHWC input tensor is assumed -- confirm for new models.
        input_shape = self._input_details[0]['shape']
        self._input_height = input_shape[1]
        self._input_width = input_shape[2]
        # Only float models get mean/std normalisation of the pixels.
        self._floating_model = (self._input_details[0]['dtype'] == np.float32)
        # Previously these were silently replaced by a hard-coded 127.5.
        self._input_mean = input_mean
        self._input_std = input_std

    def close(self):
        """Do nothing; kept so callers of the old session-based API work."""
        pass

    def read_tensor_from_image_mat(self, image_mat, input_height=299,
                                   input_width=299, input_mean=0,
                                   input_std=255):
        """Turn an image array into a batch-of-one model input.

        Args:
            image_mat: image array; it is converted with
                ``cv2.COLOR_BGR2RGB``, so BGR channel order is expected.
            input_height, input_width, input_mean, input_std: ignored,
                accepted only for backward compatibility. Size comes from
                the model and mean/std from the constructor.

        Returns:
            The resized RGB image with a leading batch axis; converted to
            float32 and normalised as ``(x - mean) / std`` for float models.
        """
        rgb = cv2.cvtColor(image_mat, cv2.COLOR_BGR2RGB)
        # cv2.resize takes the target size as (width, height).
        resized = cv2.resize(rgb, (self._input_width, self._input_height))
        batch = np.expand_dims(resized, axis=0)

        if self._floating_model:
            batch = (np.float32(batch) - self._input_mean) / self._input_std

        return batch

    def load_labels(self, label_file):
        """Return the stripped lines of ``label_file`` as a list of labels."""
        with open(label_file, 'r') as f:
            return [line.strip() for line in f]

    @staticmethod
    def _dequantize(raw_scores, quantization):
        """Map quantized scores back to real values.

        ``quantization`` is a ``(scale, zero_point)`` pair; a scale of 0 is
        treated as "not quantized" and ``raw_scores`` is returned untouched.
        """
        scale, zero_point = quantization
        if not scale:
            return raw_scores
        shifted = raw_scores.astype(np.float32) - np.float32(zero_point)
        return shifted * np.float32(scale)

    def _rank_labels(self, scores, top_results):
        """Pair labels with scores above the threshold, best first.

        Pairing stops at the shorter of labels/scores, so surplus scores
        cannot raise ``IndexError``. At most ``top_results`` pairs are
        returned; ``None`` means no limit.
        """
        hits = [(label, score)
                for label, score in zip(self._labels, scores)
                if score > self.SCORE_THRESHOLD]
        # Stable sort: equal scores keep their label-file order.
        hits.sort(key=lambda pair: pair[1], reverse=True)
        if top_results is None:
            return hits
        return hits[:max(top_results, 0)]

    def classify_image(self,
                       image_file_or_mat,
                       top_results=3):
        """Classify one image.

        Args:
            image_file_or_mat: a BGR image array, or the path of an image
                file that is read with ``cv2.imread``.
            top_results: maximum number of ``(label, score)`` pairs returned.

        Returns:
            ``(label, score)`` tuples with score above ``SCORE_THRESHOLD``,
            sorted by descending score.

        Raises:
            ValueError: if a path is given and the image cannot be read.
        """
        if isinstance(image_file_or_mat, str):
            image_mat = cv2.imread(image_file_or_mat)
            if image_mat is None:
                raise ValueError(
                    "could not read image file: %s" % image_file_or_mat)
        else:
            image_mat = image_file_or_mat
        input_image = self.read_tensor_from_image_mat(image_mat)

        logger.info("classify.0")
        self._interpreter.set_tensor(self._input_details[0]['index'], input_image)
        self._interpreter.invoke()
        logger.info("classify.1")

        output = self._output_details[0]
        # First (only) batch entry: one score per class.
        raw_scores = self._interpreter.get_tensor(output['index'])[0]
        # Quantized models emit integers; rescale so the threshold applies.
        scores = self._dequantize(raw_scores,
                                  output.get('quantization', (0.0, 0)))
        if len(scores) != len(self._labels):
            logger.warning("model returned %d scores for %d labels",
                           len(scores), len(self._labels))

        pairs = self._rank_labels(scores, top_results)
        logger.info(str(pairs))
        return pairs
|
0 commit comments