Skip to content

Commit 4e3db27

Browse files
committed
refactor augmentation to be a plugin
1 parent 10f6278 commit 4e3db27

File tree

13 files changed

+117
-263
lines changed

13 files changed

+117
-263
lines changed

configuration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
# neural network architecture to use
7474
architecture_plugin="convolutional"
7575
overlapped_prefix="not_"
76+
augmentation_plugin="volume-noise-dc-reverse-invert"
7677

7778
# on what computer to do the computation
7879
default_where="local"

src/activations

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,11 @@ def main():
157157
def infer_step(isound):
158158
# HACK: get_data not guaranteed to return isounds in order
159159
fingerprints, _, sounds = D.get_data(
160-
FLAGS.batch_size, isound, model_settings,
161-
FLAGS.loss, FLAGS.overlapped_prefix,
162-
time_shift_tics, 'testing',
163-
model.use_audio, model.use_video, video_findfile)
160+
FLAGS.batch_size, isound, model_settings,
161+
FLAGS.loss, FLAGS.overlapped_prefix,
162+
time_shift_tics, 'testing',
163+
model.use_audio, model.use_video, video_findfile,
164+
None, None)
164165
hidden_activations, logits = thismodel(fingerprints, training=False)
165166
return fingerprints, sounds, logits, hidden_activations
166167

src/data.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -411,14 +411,22 @@ def _get_data(q, o, how_many, offset, model_settings, loss, overlapped_prefix,
411411
shiftby, mode, use_audio, use_video, video_findfile,
412412
data_index, labels_list, np_rng, data_dir,
413413
audio_read_plugin, audio_read_plugin_kwargs,
414-
video_read_plugin, video_read_plugin_kwargs):
414+
video_read_plugin, video_read_plugin_kwargs,
415+
augmentation_plugin, augmentation_parameters):
415416
q.cancel_join_thread()
416417

417418
from lib import compute_background, load_audio_read_plugin, load_video_read_plugin
418419
load_audio_read_plugin(audio_read_plugin, audio_read_plugin_kwargs)
419420
load_video_read_plugin(video_read_plugin, video_read_plugin_kwargs)
420421
from lib import audio_read, video_read
421422

423+
if use_audio and mode=='training':
424+
import importlib
425+
sys.path.insert(0,os.path.dirname(augmentation_plugin))
426+
tmp = importlib.import_module(os.path.basename(augmentation_plugin))
427+
def augment(audio_slice, augmentation_parameters):
428+
return tmp.augment(audio_slice, augmentation_parameters)
429+
422430
while True:
423431
# Pick one of the partitions to choose sounds from.
424432
pick_deterministically = mode != 'training'
@@ -498,29 +506,8 @@ def _get_data(q, o, how_many, offset, model_settings, loss, overlapped_prefix,
498506
labels[i - offset, labels_list.index(root)] = target
499507
sounds[-1].append({k: v for k,v in overlapped_sound.items() if k!='overlaps'})
500508

501-
# augmentation
502509
if use_audio and mode=='training':
503-
volume_range = [float(x) for x in model_settings['augment_volume'].split(',')]
504-
noise_range = [float(x) for x in model_settings['augment_noise'].split(',')]
505-
dc_range = [float(x) for x in model_settings['augment_dc'].split(',')]
506-
reverse_bool = model_settings['augment_reverse'] == 'yes'
507-
invert_bool = model_settings['augment_invert'] == 'yes'
508-
if volume_range != [1,1]:
509-
volume_ranges = np.random.uniform(*volume_range, (nsounds,1,audio_nchannels))
510-
audio_slice *= volume_ranges
511-
if noise_range != [0,0]:
512-
noise_ranges = np.random.uniform(*noise_range, (nsounds,1,audio_nchannels))
513-
noises = np.random.normal(0, noise_ranges, audio_slice.shape)
514-
audio_slice += noises
515-
if dc_range != [0,0]:
516-
dc_ranges = np.random.uniform(*dc_range, (nsounds,1,audio_nchannels))
517-
audio_slice += dc_ranges
518-
if reverse_bool:
519-
ireverse = np.random.choice([False,True], nsounds)
520-
audio_slice[ireverse] = np.flip(audio_slice[ireverse], axis=1)
521-
if invert_bool:
522-
iinvert = np.random.choice([-1,1], (nsounds,1,1))
523-
audio_slice *= iinvert
510+
audio_slice = augment(audio_slice, augmentation_parameters)
524511

525512
if loss=='autoencoder':
526513
labels = audio_slice
@@ -533,7 +520,8 @@ def _get_data(q, o, how_many, offset, model_settings, loss, overlapped_prefix,
533520
q.put([video_slice, labels, sounds])
534521

535522
def get_data(how_many, offset, model_settings, loss, overlapped_prefix,
536-
shiftby, mode, use_audio, use_video, video_findfile):
523+
shiftby, mode, use_audio, use_video, video_findfile,
524+
augmentation_plugin, augmentation_parameters):
537525
"""Gather sounds from the data set, applying transformations as needed.
538526
539527
When the mode is 'training', a random selection of sounds will be returned,
@@ -575,7 +563,8 @@ def get_data(how_many, offset, model_settings, loss, overlapped_prefix,
575563
mode, use_audio, use_video, video_findfile,
576564
data_index, labels_list, np_rng, data_dir,
577565
audio_read_plugin, audio_read_plugin_kwargs,
578-
video_read_plugin, video_read_plugin_kwargs),
566+
video_read_plugin, video_read_plugin_kwargs,
567+
augmentation_plugin, augmentation_parameters),
579568
daemon=True)
580569
p.start()
581570
processes[mode].append(p)

src/generalize

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,8 @@
4040
# --video_channels=0 \
4141
# --batch_seed=_1 \
4242
# --weights_seed=_1 \
43-
# --augment_volume=1,1 \
44-
# --augment_noise=0,0 \
45-
# --augment_dc=0,0 \
46-
# --augment_reverse=no \
47-
# --augment_invert=no \
43+
# --augmentation_plugin=volume-noise-dc-invert-reverse \
44+
# --augmentation_parameters='{"volume":"1,1", "noise":"0,0", "dc:"0,0", "reverse":"no", "invert":"no"}' \
4845
# --deterministic=0 \
4946
# --igpu=0 \
5047
# --ioffset=3 \
@@ -138,11 +135,8 @@ def main():
138135
"--video_channels="+FLAGS.video_channels,
139136
"--random_seed_batch="+str(FLAGS.batch_seed),
140137
"--random_seed_weights="+str(FLAGS.weights_seed),
141-
"--augment_volume="+str(FLAGS.augment_volume),
142-
"--augment_noise="+str(FLAGS.augment_noise),
143-
"--augment_dc="+str(FLAGS.augment_dc),
144-
"--augment_reverse="+str(FLAGS.augment_reverse),
145-
"--augment_invert="+str(FLAGS.augment_invert),
138+
"--augmentation_plugin="+FLAGS.augmentation_plugin,
139+
"--augmentation_parameters="+FLAGS.augmentation_parameters.replace('<','^^^<').replace('>','^^^>'),
146140
"--deterministic="+FLAGS.deterministic,
147141
"--train_dir="+os.path.join(FLAGS.logdir,"generalize_"+model),
148142
"--summaries_dir="+os.path.join(FLAGS.logdir,"summaries_"+model),
@@ -305,30 +299,15 @@ if __name__ == '__main__':
305299
default=59185,
306300
help='Randomize weight initialization if -1; otherwise use supplied number as seed.')
307301
parser.add_argument(
308-
'--augment_volume',
302+
'--augmentation_plugin',
309303
type=str,
310-
default='1,1',
311-
help='Multiply each annotation by a uniform random number in this interval when training')
312-
parser.add_argument(
313-
'--augment_noise',
314-
type=str,
315-
default='0,0',
316-
help='Add noise to each annotation with a uniform random std dev in this interval when training')
317-
parser.add_argument(
318-
'--augment_dc',
319-
type=str,
320-
default='0,0',
321-
help='Add to each annotation a uniform random number in this interval when training')
322-
parser.add_argument(
323-
'--augment_reverse',
324-
type=str,
325-
default='no',
326-
help='Flip in time with a probability of half each annotation when training')
304+
default='{}',
305+
help='What augmentation plugin to use')
327306
parser.add_argument(
328-
'--augment_invert',
307+
'--augmentation_parameters',
329308
type=str,
330-
default='no',
331-
help='Negate with a probability of half each annotation when training')
309+
default='{}',
310+
help='What augmentation parameters to use')
332311
parser.add_argument(
333312
'--model_architecture',
334313
type=str,

src/gui/controller.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,11 +1224,8 @@ async def train_actuate():
12241224
"--video_channels="+str(M.video_channels), \
12251225
"--batch_seed="+V.batch_seed.value, \
12261226
"--weights_seed="+V.weights_seed.value, \
1227-
"--augment_volume="+V.augment_volume.value, \
1228-
"--augment_noise="+V.augment_noise.value, \
1229-
"--augment_dc="+V.augment_dc.value, \
1230-
"--augment_reverse="+V.augment_reverse.value, \
1231-
"--augment_invert="+V.augment_invert.value, \
1227+
"--augmentation_plugin="+M.augmentation_plugin, \
1228+
"--augmentation_parameters="+json.dumps({k:v.value for k,v in V.augmentation_parameters.items()}), \
12321229
"--deterministic="+M.deterministic, \
12331230
"--igpu=QUEUE1", \
12341231
"--ireplicates="+','.join([str(x) for x in range(ireplicate, min(1+nreplicates, \
@@ -1325,11 +1322,8 @@ async def leaveout_actuate(kind):
13251322
"--video_channels="+str(M.video_channels), \
13261323
"--batch_seed="+V.batch_seed.value, \
13271324
"--weights_seed="+V.weights_seed.value, \
1328-
"--augment_volume="+V.augment_volume.value, \
1329-
"--augment_noise="+V.augment_noise.value, \
1330-
"--augment_dc="+V.augment_dc.value, \
1331-
"--augment_reverse="+V.augment_reverse.value, \
1332-
"--augment_invert="+V.augment_invert.value, \
1325+
"--augmentation_plugin="+M.augmentation_plugin, \
1326+
"--augmentation_parameters="+json.dumps({k:v.value for k,v in V.augmentation_parameters.items()}), \
13331327
"--deterministic="+M.deterministic, \
13341328
"--ioffset="+str(ivalidation_file),
13351329
"--igpu=QUEUE1", \
@@ -1401,11 +1395,8 @@ async def xvalidate_actuate():
14011395
"--video_channels="+str(M.video_channels), \
14021396
"--batch_seed="+V.batch_seed.value, \
14031397
"--weights_seed="+V.weights_seed.value, \
1404-
"--augment_volume="+V.augment_volume.value, \
1405-
"--augment_noise="+V.augment_noise.value, \
1406-
"--augment_dc="+V.augment_dc.value, \
1407-
"--augment_reverse="+V.augment_reverse.value, \
1408-
"--augment_invert="+V.augment_invert.value, \
1398+
"--augmentation_plugin="+M.augmentation_plugin, \
1399+
"--augmentation_parameters="+json.dumps({k:v.value for k,v in V.augmentation_parameters.items()}), \
14091400
"--deterministic="+M.deterministic, \
14101401
"--igpu=QUEUE1", \
14111402
"--kfold="+V.kfold.value, \
@@ -2251,21 +2242,6 @@ def _copy_callback():
22512242
elif "random_seed_weights = " in line:
22522243
m=re.search('random_seed_weights = (.*)', line)
22532244
V.weights_seed.value = m.group(1)
2254-
elif "augment_volume = " in line:
2255-
m=re.search('augment_volume = (.*)', line)
2256-
V.augment_volume.value = m.group(1)
2257-
elif "augment_noise = " in line:
2258-
m=re.search('augment_noise = (.*)', line)
2259-
V.augment_noise.value = m.group(1)
2260-
elif "augment_dc = " in line:
2261-
m=re.search('augment_dc = (.*)', line)
2262-
V.augment_dc.value = m.group(1)
2263-
elif "augment_reverse = " in line:
2264-
m=re.search('augment_reverse = (.*)', line)
2265-
V.augment_reverse.value = m.group(1)
2266-
elif "augment_invert = " in line:
2267-
m=re.search('augment_invert = (.*)', line)
2268-
V.augment_invert.value = m.group(1)
22692245
elif "validate_step_period = " in line:
22702246
m=re.search('validate_step_period = (\d+)', line)
22712247
V.save_and_validate_period.value = m.group(1)
@@ -2312,6 +2288,11 @@ def _copy_callback():
23122288
params = json.loads(m.group(1).replace("'",'"'))
23132289
for k,v in params.items():
23142290
V.model_parameters[k].value = v
2291+
elif "augmentation_parameters = " in line:
2292+
m=re.search('augmentation_parameters = ({.*})', line)
2293+
params = json.loads(m.group(1).replace("'",'"'))
2294+
for k,v in params.items():
2295+
V.augmentation_parameters[k].value = v
23152296
_copy_callback_finalize()
23162297

23172298
def copy_callback():

src/gui/main.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
doubleclick_parameters = list(V.doubleclick_parameters.values())
3939
model_parameters = list(V.model_parameters.values())
4040
cluster_parameters = list(V.cluster_parameters.values())
41+
augmentation_parameters = list(V.augmentation_parameters.values())
4142

4243
main_content = row(
4344
column(
@@ -128,12 +129,10 @@
128129
for c in r])
129130
for r in V.cluster_parameters_partitioned],
130131
background="honeydew"),
131-
column(row(V.augment_volume, V.augment_noise,
132-
width=M.gui_width_pix//11*2),
133-
row(V.augment_dc, V.augment_reverse,
134-
width=M.gui_width_pix//11*2),
135-
row(V.augment_invert,
136-
width=M.gui_width_pix//11),
132+
column(*[row(*[column(augmentation_parameters[c],
133+
width=round(M.gui_width_pix/11*V.augmentation_parameters_width[c]))
134+
for c in r])
135+
for r in V.augmentation_parameters_partitioned],
137136
background="azure"),
138137
column(*[row(*[column(model_parameters[c],
139138
width=round(M.gui_width_pix/11*V.model_parameters_width[c]))

src/gui/model.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,6 @@ def save_state_callback():
5757
'nreplicates': V.nreplicates.value,
5858
'batch_seed': V.batch_seed.value,
5959
'weights_seed': V.weights_seed.value,
60-
'augment_volume': V.augment_volume.value,
61-
'augment_noise': V.augment_noise.value,
62-
'augment_dc': V.augment_dc.value,
63-
'augment_reverse': V.augment_reverse.value,
64-
'augment_invert': V.augment_invert.value,
6560
'labels': str.join(',',[x.value for x in V.label_texts]),
6661
'file_dialog_string': V.file_dialog_string.value,
6762
'context': V.context.value,
@@ -72,7 +67,8 @@ def save_state_callback():
7267
**{k:v.value for k,v in V.detect_parameters.items()},
7368
**{k:v.value for k,v in V.doubleclick_parameters.items()},
7469
**{k:v.value for k,v in V.model_parameters.items()},
75-
**{k:v.value for k,v in V.cluster_parameters.items()}},
70+
**{k:v.value for k,v in V.cluster_parameters.items()},
71+
**{k:v.value for k,v in V.augmentation_parameters.items()}},
7672
fid)
7773

7874
def isannotated(sound):
@@ -253,7 +249,7 @@ def init(_bokeh_document, _configuration_file, _use_aitch):
253249
global user_changed_recording, user_copied_parameters
254250
global audio_read, audio_read_exts, audio_read_rec2ch, audio_read_strip_rec, trim_ext
255251
global video_read, detect_labels, doubleclick_annotation, context_data, context_data_istart, model, video_findfile
256-
global detect_parameters, doubleclick_parameters, model_parameters, cluster_parameters
252+
global detect_parameters, doubleclick_parameters, model_parameters, cluster_parameters, augmentation_parameters
257253

258254
bokeh_document = _bokeh_document
259255

@@ -293,6 +289,10 @@ def init(_bokeh_document, _configuration_file, _use_aitch):
293289
tmp = importlib.import_module(os.path.basename(cluster_plugin))
294290
cluster_parameters = tmp.cluster_parameters()
295291

292+
sys.path.insert(0,os.path.dirname(augmentation_plugin))
293+
tmp = importlib.import_module(os.path.basename(augmentation_plugin))
294+
augmentation_parameters = tmp.augmentation_parameters()
295+
296296
sys.path.insert(0,os.path.dirname(video_findfile_plugin))
297297
video_findfile = importlib.import_module(os.path.basename(video_findfile_plugin)).video_findfile
298298

@@ -499,11 +499,6 @@ def is_local_server_or_cluster(varname, varvalue):
499499
'nreplicates':'1', \
500500
'batch_seed':'-1', \
501501
'weights_seed':'-1', \
502-
'augment_volume':'1,1', \
503-
'augment_noise':'0,0', \
504-
'augment_dc':'0,0', \
505-
'augment_reverse':'no', \
506-
'augment_invert':'no', \
507502
'labels':','*(nlabels-1), \
508503
'file_dialog_string':os.getcwd(), \
509504
'context':str(0.2048 / time_scale), \
@@ -514,7 +509,8 @@ def is_local_server_or_cluster(varname, varvalue):
514509
**{x[0]:x[3] for x in detect_parameters}, \
515510
**{x[0]:x[3] for x in doubleclick_parameters}, \
516511
**{x[0]:x[3] for x in model_parameters},
517-
**{x[0]:x[3] for x in cluster_parameters}},
512+
**{x[0]:x[3] for x in cluster_parameters},
513+
**{x[0]:x[3] for x in augmentation_parameters}},
518514
fid)
519515

520516
with open(statepath, 'r') as fid:

0 commit comments

Comments
 (0)