Skip to content

Commit 6068801

Browse files
committed
updated SpeedTrigger/FreezeTrigger and added example SpeedExperiment
1 parent 08975df commit 6068801

File tree

2 files changed

+121
-38
lines changed

2 files changed

+121
-38
lines changed

experiments/custom/experiments.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from functools import partial
1212
from collections import Counter
1313
from experiments.custom.stimulus_process import ClassicProtocolProcess, SimpleProtocolProcess,Timer, ExampleProtocolProcess
14-
from experiments.custom.triggers import ScreenTrigger, RegionTrigger, OutsideTrigger, DirectionTrigger
14+
from experiments.custom.triggers import ScreenTrigger, RegionTrigger, OutsideTrigger, DirectionTrigger, SpeedTrigger
1515
from utils.plotter import plot_triggers_response
1616
from utils.analysis import angle_between_vectors
1717
from experiments.custom.stimulation import show_visual_stim_img,laser_switch
@@ -137,6 +137,80 @@ def get_trial(self):
137137
EXP_COMPLETION = 10
138138

139139

140+
141+
class SpeedExperiment:
142+
"""
143+
Simple class to contain all of the experiment properties
144+
Uses multiprocess to ensure the best possible performance and
145+
to showcase that it is possible to work with any type of equipment, even timer-dependent
146+
"""
147+
def __init__(self):
148+
self.experiment_finished = False
149+
self._threshold = 10
150+
self._event = None
151+
self._current_trial = None
152+
self._event_count = 0
153+
self._trigger = SpeedTrigger(threshold = self._threshold,bodypart= 'tailroot', timewindow_len= 5)
154+
self._exp_timer = Timer(600)
155+
156+
def check_skeleton(self, frame, skeleton):
157+
"""
158+
Checking each passed animal skeleton for a pre-defined set of conditions
159+
Outputting the visual representation, if exist
160+
Advancing trials according to inherent logic of an experiment
161+
:param frame: frame, on which animal skeleton was found
162+
:param skeleton: skeleton, consisting of multiple joints of an animal
163+
"""
164+
self.check_exp_timer() # checking if experiment is still on
165+
166+
if not self.experiment_finished:
167+
result, response = self._trigger.check_skeleton(skeleton=skeleton)
168+
plot_triggers_response(frame, response)
169+
if result:
170+
laser_switch(True)
171+
self._event_count += 1
172+
print(self._event_count)
173+
print('Light on')
174+
175+
else:
176+
laser_switch(False)
177+
print('Light off')
178+
179+
return result, response
180+
181+
182+
def check_exp_timer(self):
183+
"""
184+
Checking the experiment timer
185+
"""
186+
if not self._exp_timer.check_timer():
187+
print("Experiment is finished")
188+
print("Time ran out.")
189+
self.stop_experiment()
190+
191+
def start_experiment(self):
192+
"""
193+
Start the experiment
194+
"""
195+
if not self.experiment_finished:
196+
self._exp_timer.start()
197+
198+
def stop_experiment(self):
199+
"""
200+
Stop the experiment and reset the timer
201+
"""
202+
self.experiment_finished = True
203+
print('Experiment completed!')
204+
self._exp_timer.reset()
205+
# don't forget to stop the laser for safety!
206+
laser_switch(False)
207+
208+
def get_trial(self):
209+
"""
210+
Check which trial is going on right now
211+
"""
212+
return self._current_trial
213+
140214
class FirstExperiment:
141215
def __init__(self):
142216
self.experiment_finished = False

experiments/custom/triggers.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
from utils.analysis import angle_between_vectors, calculate_distance, EllipseROI, RectangleROI
1111
from utils.configloader import RESOLUTION
12-
12+
from collections import deque
13+
import numpy as np
1314
"""Single posture triggers"""
1415

1516
class HeaddirectionROITrigger:
@@ -333,15 +334,20 @@ def check_skeleton(self, skeleton: dict):
333334

334335
class FreezeTrigger:
335336
"""
336-
Trigger to check if animal is in freezing state
337+
Trigger to check if animal is moving below a certain speed
337338
"""
338-
def __init__(self, threshold: int, debug: bool = False):
339+
def __init__(self, threshold: int, bodypart: str, timewindow_len:int = 2, debug: bool = False):
339340
"""
340341
Initializing trigger with given threshold
341342
:param threshold: int in pixel how much of a movement does not count
342-
For example threshold of 5 would mean that all movements less then 5 pixels would be ignored
343+
:param bodypart: str of body part in skeleton used for speed calculation
344+
For example threshold of 5 would mean that all movements more then 5 pixels in the last timewindow length frames
345+
would be ignored
343346
"""
347+
self._bodypart = bodypart
344348
self._threshold = threshold
349+
self._timewindow_len = timewindow_len
350+
self._timewindow = deque(maxlen= timewindow_len)
345351
self._skeleton = None
346352
self._debug = debug # not used in this trigger
347353

@@ -354,45 +360,52 @@ def check_skeleton(self, skeleton: dict):
354360
"""
355361
# choosing a point to draw near the skeleton
356362
org_point = skeleton[list(skeleton.keys())[0]]
357-
joint_moved = []
363+
joint_moved = 0
364+
358365
if self._skeleton is None:
359366
result = False
360-
text = 'Not freezing'
367+
text = '...'
361368
self._skeleton = skeleton
362369
else:
363-
for joint in skeleton:
364-
joint_travel = calculate_distance(skeleton[joint], self._skeleton[joint])
365-
joint_moved.append(abs(joint_travel) <= self._threshold)
366-
if all(joint_moved):
370+
joint_travel = calculate_distance(skeleton[self._bodypart], self._skeleton[self._bodypart])
371+
print(joint_travel)
372+
self._timewindow.append(joint_travel)
373+
print(self._timewindow)
374+
if len(self._timewindow) == self._timewindow_len:
375+
joint_moved = np.sum(self._timewindow)
376+
print(joint_moved)
377+
378+
if abs(joint_moved) <= self._threshold:
367379
result = True
368380
text = 'Freezing'
369381
else:
370382
result = False
371-
text = 'Not freezing'
372-
self._skeleton = skeleton
373-
383+
text = 'Not Freezing'
384+
self._skeleton = skeleton
374385
color = (0, 255, 0) if result else (0, 0, 255)
375386
response_body = {'plot': {'text': dict(text=text,
376387
org=org_point,
377388
color=color)}}
378389
response = (result, response_body)
379-
return response
380390

391+
return response
381392

382393
class SpeedTrigger:
383394
"""
384395
Trigger to check if animal is moving above a certain speed
385396
"""
386-
def __init__(self, threshold: int, bodypart: str = 'any', debug: bool = False):
397+
def __init__(self, threshold: int, bodypart: str, timewindow_len:int = 2, debug: bool = False):
387398
"""
388399
Initializing trigger with given threshold
389400
:param threshold: int in pixel how much of a movement does not count
390-
:param bodypart: str or list of str, bodypart or list of bodyparts in skeleton to use for trigger,
391-
if "any" will check if any bodypart reaches treshold; default "any"
392-
For example threshold of 5 would mean that all movements less then 5 pixels would be ignored
401+
:param bodypart: str of body part in skeleton used for speed calculation
402+
For example threshold of 5 would mean that all movements less then 5 pixels in the last timewindow length frames
403+
would be ignored
393404
"""
394405
self._bodypart = bodypart
395406
self._threshold = threshold
407+
self._timewindow_len = timewindow_len
408+
self._timewindow = deque(maxlen= timewindow_len)
396409
self._skeleton = None
397410
self._debug = debug # not used in this trigger
398411

@@ -405,36 +418,32 @@ def check_skeleton(self, skeleton: dict):
405418
"""
406419
# choosing a point to draw near the skeleton
407420
org_point = skeleton[list(skeleton.keys())[0]]
408-
joint_moved = []
421+
joint_moved = 0
422+
409423
if self._skeleton is None:
410424
result = False
411-
text = 'First frame'
425+
text = '...'
412426
self._skeleton = skeleton
413427
else:
414-
if self._bodypart is "any":
415-
for joint in skeleton:
416-
joint_travel = calculate_distance(skeleton[joint], self._skeleton[joint])
417-
joint_moved.append(abs(joint_travel) >= self._threshold)
418-
419-
elif isinstance(self._bodypart, list):
420-
for joint in self._bodypart:
421-
joint_travel = calculate_distance(skeleton[joint], self._skeleton[joint])
422-
joint_moved.append(abs(joint_travel) >= self._threshold)
423-
else:
424-
joint_travel = calculate_distance(skeleton[self._bodypart], self._skeleton[self._bodypart])
425-
joint_moved.append(abs(joint_travel) >= self._threshold)
426-
427-
if all(joint_moved):
428+
joint_travel = calculate_distance(skeleton[self._bodypart], self._skeleton[self._bodypart])
429+
print(joint_travel)
430+
self._timewindow.append(joint_travel)
431+
print(self._timewindow)
432+
if len(self._timewindow) == self._timewindow_len:
433+
joint_moved = np.sum(self._timewindow)
434+
print(joint_moved)
435+
436+
if abs(joint_moved) >= self._threshold:
428437
result = True
429438
text = 'Running'
430439
else:
431440
result = False
432-
text = ''
433-
self._skeleton = skeleton
434-
441+
text = 'Not Running'
442+
self._skeleton = skeleton
435443
color = (0, 255, 0) if result else (0, 0, 255)
436444
response_body = {'plot': {'text': dict(text=text,
437445
org=org_point,
438446
color=color)}}
439447
response = (result, response_body)
448+
440449
return response

0 commit comments

Comments
 (0)