|
23 | 23 | import logging |
24 | 24 |
|
25 | 25 | import numpy as np |
26 | | -import tensorflow as tf |
| 26 | +from tensorflow.lite.python.interpreter import Interpreter |
| 27 | +import cv2 |
27 | 28 |
|
28 | 29 | logger = logging.getLogger(__name__) |
29 | 30 |
|
class CNNClassifier(object):
    """Image classifier backed by a TensorFlow Lite interpreter.

    The model is loaded once at construction time; ``classify_image`` then
    runs single-image inference and returns ``(label, score)`` pairs.
    """

    def __init__(self, model_file, label_file, input_layer="input",
                 output_layer="final_result", input_height=128,
                 input_width=128, input_mean=127.5, input_std=127.5):
        """Load the TFLite model and its label file.

        Args:
            model_file: Path to the ``.tflite`` model.
            label_file: Path to a text file with one label per line.
            input_layer: Unused; kept so existing callers keep working.
            output_layer: Unused; kept so existing callers keep working.
            input_height: Unused; the height is read from the model's
                input tensor instead.
            input_width: Unused; the width is read from the model's
                input tensor instead.
            input_mean: Value subtracted from every pixel when the model
                takes float32 input.
            input_std: Value every pixel is divided by (after subtracting
                ``input_mean``) when the model takes float32 input.
        """
        logger.info(model_file)
        self._interpreter = Interpreter(model_path=model_file)
        self._interpreter.set_num_threads(4)
        self._interpreter.allocate_tensors()
        self._labels = self.load_labels(label_file)
        self._input_details = self._interpreter.get_input_details()
        self._output_details = self._interpreter.get_output_details()
        # NOTE(review): assumes the input tensor is laid out as
        # (batch, height, width, channels) - confirm against the model.
        input_shape = self._input_details[0]['shape']
        self._input_height = int(input_shape[1])
        self._input_width = int(input_shape[2])
        # Only float32 models get mean/std normalization; other input
        # dtypes are fed the resized pixels unchanged.
        self._floating_model = (self._input_details[0]['dtype'] == np.float32)
        self._input_mean = input_mean
        self._input_std = input_std

    def close(self):
        """Do nothing; retained so callers that call ``close()`` still work."""

    def read_tensor_from_image_mat(self, image_mat, input_height=299,
                                   input_width=299, input_mean=0,
                                   input_std=255):
        """Convert an image matrix into a batch-of-one model input.

        The image is converted with ``cv2.COLOR_BGR2RGB``, resized to the
        model's input size and given a leading batch axis. For float32
        models the pixels are additionally normalized with the mean/std
        given to the constructor.

        Args:
            image_mat: Image array accepted by ``cv2.cvtColor`` with
                ``COLOR_BGR2RGB`` (i.e. treated as BGR).
            input_height: Ignored; kept for backward compatibility.
            input_width: Ignored; kept for backward compatibility.
            input_mean: Ignored; the constructor's value is used.
            input_std: Ignored; the constructor's value is used.

        Returns:
            The prepared input array with a leading batch axis of size 1.
        """
        frame_rgb = cv2.cvtColor(image_mat, cv2.COLOR_BGR2RGB)
        frame_resized = cv2.resize(
            frame_rgb, (self._input_width, self._input_height))
        input_data = np.expand_dims(frame_resized, axis=0)

        if self._floating_model:
            normalized = (np.float32(input_data) - self._input_mean) / self._input_std
            # Keep float32 even if mean/std were passed as wider types.
            input_data = normalized.astype(np.float32, copy=False)

        return input_data

    def load_labels(self, label_file):
        """Return the labels in ``label_file``, one per line, stripped."""
        with open(label_file, 'r') as f:
            return [line.strip() for line in f]

    @staticmethod
    def _dequantize(scores, output_details):
        """Map raw output values to real values using the tensor's
        ``quantization`` (scale, zero_point); a scale of 0 means the
        tensor is not quantized and ``scores`` is returned unchanged."""
        scale, zero_point = output_details.get('quantization', (0.0, 0))
        if not scale:
            return scores
        return (scores.astype(np.float32) - zero_point) * scale

    def _select_results(self, scores, top_results, min_score):
        """Return up to ``top_results`` (label, score) pairs whose score
        exceeds ``min_score``, ordered from highest to lowest score."""
        # Ignore scores that have no matching label instead of raising
        # IndexError when the label file is shorter than the output.
        usable = min(len(scores), len(self._labels))
        scores = scores[:usable]
        # Negate a float copy so unsigned dtypes cannot wrap around.
        order = np.argsort(-np.asarray(scores, dtype=np.float64), kind="stable")
        ranked = [(self._labels[i], scores[i])
                  for i in order if scores[i] > min_score]
        return ranked[:top_results]

    def classify_image(self,
                       image_file_or_mat,
                       top_results=3,
                       min_score=0.5):
        """Classify one image and return the best-scoring labels.

        Args:
            image_file_or_mat: Either a path to an image file (read with
                ``cv2.imread``) or an image matrix as accepted by
                ``read_tensor_from_image_mat``.
            top_results: Maximum number of pairs to return.
            min_score: Only labels scoring strictly above this value are
                returned.

        Returns:
            A list of ``(label, score)`` tuples sorted by descending
            score; empty when nothing exceeds ``min_score``.

        Raises:
            ValueError: If a file path is given and the image cannot be
                read.
        """
        if isinstance(image_file_or_mat, str):
            image_mat = cv2.imread(image_file_or_mat)
            if image_mat is None:
                raise ValueError(
                    "Cannot read image file: %s" % image_file_or_mat)
        else:
            image_mat = image_file_or_mat
        input_image = self.read_tensor_from_image_mat(image_mat)

        logger.info("classify.0")
        self._interpreter.set_tensor(self._input_details[0]['index'], input_image)
        self._interpreter.invoke()
        logger.info("classify.1")

        # First (only) batch entry of the first output tensor: one score
        # per label index.
        output = self._output_details[0]
        raw_scores = self._interpreter.get_tensor(output['index'])[0]
        scores = self._dequantize(raw_scores, output)

        pairs = self._select_results(scores, top_results, min_score)
        logger.info("%s", pairs)
        return pairs
0 commit comments