|
8 | 8 |
|
9 | 9 | import numpy |
10 | 10 | import serial |
| 11 | +import scipy.signal |
11 | 12 |
|
12 | 13 | import matplotlib |
13 | 14 | from matplotlib import pyplot as plt |
@@ -75,62 +76,115 @@ def create_interactive(): |
75 | 76 | win = Gtk.Window() |
76 | 77 | win.connect("delete-event", Gtk.main_quit) |
77 | 78 | win.set_default_size(400, 300) |
78 | | - win.set_title("Embedding in GTK") |
| 79 | + win.set_title("On-sensor Audio Classification") |
79 | 80 |
|
80 | | - f = matplotlib.figure.Figure(figsize=(5, 4), dpi=100) |
81 | | - ax = f.add_subplot(111) |
82 | | - t = numpy.arange(0.0, 3.0, 0.01) |
83 | | - s = numpy.sin(2*numpy.pi*t) |
84 | | - |
85 | | - #ax.plot(t, s) |
| 81 | + fig, (ax, text_ax) = plt.subplots(1, 2) |
86 | 82 |
|
87 | 83 | sw = Gtk.ScrolledWindow() |
88 | 84 | win.add(sw) |
89 | 85 | # A scrolled window border goes outside the scrollbars and viewport |
90 | 86 | sw.set_border_width(10) |
91 | 87 |
|
92 | | - canvas = FigureCanvas(f) # a Gtk.DrawingArea |
93 | | - canvas.set_size_request(800, 600) |
| 88 | + canvas = FigureCanvas(fig) # a Gtk.DrawingArea |
| 89 | + canvas.set_size_request(200, 400) |
94 | 90 | sw.add_with_viewport(canvas) |
95 | 91 |
|
96 | | - predictions = numpy.random.random(10) |
97 | | - rects = ax.bar(numpy.arange(len(predictions)), predictions, align='center', alpha=0.5) |
| 92 | + prediction_threshold = 0.35 |
98 | 93 |
|
99 | | - return win, f, ax, rects |
| 94 | + # Plots |
| 95 | + predictions = numpy.zeros(11) |
| 96 | + tt = numpy.arange(len(predictions)) |
| 97 | + rects = ax.barh(tt, predictions, align='center', alpha=0.5) |
| 98 | + ax.set_yticks(tt) |
| 99 | + ax.set_yticklabels(classnames) |
| 100 | + ax.set_xlim(0, 1) |
100 | 101 |
|
101 | | -def update_plot(ser, ax, fig, rects): |
102 | | - raw = ser.readline() |
103 | | - line = raw.decode('utf-8') |
104 | | - predictions = parse_input(line) |
| 102 | + ax.axvline(prediction_threshold) |
| 103 | + ax.yaxis.set_ticks_position('right') |
| 104 | + |
| 105 | + # Text |
| 106 | + text_ax.axes.get_xaxis().set_visible(False) |
| 107 | + text_ax.axes.get_yaxis().set_visible(False) |
| 108 | + |
| 109 | + text = text_ax.text(0.5, 0.2, "Unknown", |
| 110 | + horizontalalignment='center', |
| 111 | + verticalalignment='center', |
| 112 | + fontsize=32, |
| 113 | + ) |
| 114 | + |
| 115 | + def emwa(new, prev, alpha): |
| 116 | + return alpha * new + (1 - alpha) * prev |
| 117 | + |
| 118 | + prev = predictions |
| 119 | + alpha = 0.2 # smoothing coefficient |
| 120 | + |
| 121 | + window = numpy.zeros(shape=(4, 11)) |
| 122 | + |
| 123 | + from scipy.ndimage.interpolation import shift |
| 124 | + |
| 125 | + def update_plot(predictions): |
| 126 | + |
| 127 | + if len(predictions) < 10: |
| 128 | + return |
| 129 | + |
| 130 | + # add unknown class |
| 131 | + predictions = numpy.concatenate([predictions, [0.0]]) |
| 132 | + |
| 133 | + window[:, :] = numpy.roll(window, 1, axis=0) |
| 134 | + window[0, :] = predictions |
| 135 | + |
| 136 | + predictions = numpy.mean(window, axis=0) |
105 | 137 |
|
106 | | - if predictions: |
107 | 138 | best_p = numpy.max(predictions) |
108 | 139 | best_c = numpy.argmax(predictions) |
109 | | - name = classnames[best_c] |
110 | | - if best_p >= 0.35: |
111 | | - print('p', name, best_p) |
| 140 | + if best_p <= prediction_threshold: |
| 141 | + best_c = 10 |
| 142 | + best_p = 0.0 |
112 | 143 |
|
113 | 144 | for rect, h in zip(rects, predictions): |
114 | | - rect.set_height(h) |
| 145 | + rect.set_width(h) |
| 146 | + |
| 147 | + name = classnames[best_c] |
| 148 | + text.set_text(name) |
| 149 | + |
| 150 | + fig.tight_layout() |
| 151 | + fig.canvas.draw() |
| 152 | + |
| 153 | + return win, update_plot |
| 154 | + |
| 155 | +def fetch_predictions(ser): |
| 156 | + raw = ser.readline() |
| 157 | + line = raw.decode('utf-8') |
| 158 | + predictions = parse_input(line) |
| 159 | + return predictions |
115 | 160 |
|
116 | | - fig.canvas.draw() |
117 | 161 |
|
118 | | - return True |
119 | 162 |
|
120 | 163 | def main(): |
121 | 164 | test_parse_preds() |
122 | 165 |
|
123 | 166 | device = '/dev/ttyACM1' |
124 | 167 | baudrate = 115200 |
125 | 168 |
|
126 | | - window, fig, ax, rects = create_interactive() |
| 169 | + window, plot = create_interactive() |
127 | 170 | window.show_all() |
128 | 171 |
|
| 172 | + def update(ser): |
| 173 | + try: |
| 174 | + preds = fetch_predictions(ser) |
| 175 | + except Exception as e: |
| 176 | + print('error', e) |
| 177 | + return True |
| 178 | + |
| 179 | + if preds is not None: |
| 180 | + plot(preds) |
| 181 | + return True |
| 182 | + |
129 | 183 | with serial.Serial(device, baudrate, timeout=0.1) as ser: |
130 | 184 | # avoid reading stale data |
131 | 185 | thrash = ser.read(10000) |
132 | | - |
133 | | - GLib.timeout_add(200.0, update_plot, ser, ax, fig, rects) |
| 186 | + |
| 187 | + GLib.timeout_add(200.0, update, ser) |
134 | 188 |
|
135 | 189 | Gtk.main() # WARN: blocking |
136 | 190 |
|
|
0 commit comments