-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
60 lines (43 loc) · 1.84 KB
/
utils.py
File metadata and controls
60 lines (43 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
### Utils
import h5py
import os
import matplotlib.pyplot as plt
import numpy as np
def load_it_data(path_to_data):
    """Load the IT dataset stored in ``IT_data.h5``.

    Args:
        path_to_data (str): Directory that contains ``IT_data.h5``.

    Returns:
        tuple: Eight items, in this order:
            stimulus_train, stimulus_val, stimulus_test,
            objects_train, objects_val, objects_test,
            spikes_train, spikes_val.
        The stimulus and spikes entries are the corresponding HDF5 datasets
        read fully into memory; the objects entries are lists of ``str``
        obtained by decoding each stored item as latin-1. There is no
        ``spikes_test`` entry: the function does not read one.
    """
    file_path = os.path.join(path_to_data, 'IT_data.h5')

    # Use a context manager so the HDF5 handle is released even if one of the
    # datasets is missing. ``[()]`` copies each dataset into memory, so the
    # values stay valid after the file is closed.
    with h5py.File(file_path, 'r') as datafile:
        stimulus_train = datafile['stimulus_train'][()]
        spikes_train = datafile['spikes_train'][()]
        objects_train = datafile['object_train'][()]

        stimulus_val = datafile['stimulus_val'][()]
        spikes_val = datafile['spikes_val'][()]
        objects_val = datafile['object_val'][()]

        stimulus_test = datafile['stimulus_test'][()]
        objects_test = datafile['object_test'][()]

    # Object names are decoded from latin-1 to ``str``.
    objects_train = [obj_tmp.decode("latin-1") for obj_tmp in objects_train]
    objects_val = [obj_tmp.decode("latin-1") for obj_tmp in objects_val]
    objects_test = [obj_tmp.decode("latin-1") for obj_tmp in objects_test]

    return (
        stimulus_train, stimulus_val, stimulus_test,
        objects_train, objects_val, objects_test,
        spikes_train, spikes_val,
    )
def visualize_img(stimulus,objects,stim_idx):
    """Display one stimulus image with its object name as the figure title.

    Args:
        stimulus (array of float): Stimulus containing all the images
        objects (list of str): Object list containing all the names
        stim_idx (int): Index of the stimulus to plot
    """
    # Per-channel statistics used to undo the normalization
    # (presumably the values the stimuli were normalized with - confirm).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    # Move the first axis of the selected image to the last position
    # (assumes channels-first input - TODO confirm).
    picture = np.transpose(stimulus[stim_idx], (1, 2, 0))

    # Revert the normalization step by step, then rescale by 255.
    picture = picture * channel_std
    picture = picture + channel_mean
    picture = picture * 255
    pixels = picture.astype(np.uint8)

    plt.figure()
    plt.imshow(pixels, cmap='gray')
    plt.title(str(objects[stim_idx]))
    plt.show()
    return None