-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
executable file
·170 lines (130 loc) · 4.45 KB
/
utils.py
File metadata and controls
executable file
·170 lines (130 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import matplotlib.pyplot as plt
import os
from os.path import join, basename, dirname, exists
import torch
from torchvision import transforms
import json
import pandas as pd
def get_paths(folder_path, recurse=True, extensions=''):
    """
    Grab all relevant file paths from folder_path with extension.

    Parameters
    ----------
    folder_path : PATH-STR
        Path to parent folder.
    recurse : BOOL, optional
        T/F if we want to recurse all subdirectories. The default is True.
    extensions : STR or TUPLE of STR, optional
        Extension(s) kept files must end with, as accepted by str.endswith.
        The default is '' (matches every file).

    Returns
    -------
    List of file paths (empty list if folder_path does not exist).
    """
    # Make sure directory exists
    if not os.path.exists(folder_path):
        print('ERROR: FOLDER_PATH NOT FOUND')
        return []
    file_paths = []
    if recurse:
        # os.walk already separates files from directories, so `files`
        # never contains subdirectory names
        for folder, subs, files in os.walk(folder_path):
            file_paths.extend([os.path.join(folder, file)
                               for file in files if file.endswith(extensions)])
    else:
        # os.listdir also returns directory names; filter to regular files
        # so both branches return files only (the original non-recursive
        # branch could include subdirectories in the result)
        file_paths = [os.path.join(folder_path, name)
                      for name in os.listdir(folder_path)
                      if name.endswith(extensions)
                      and os.path.isfile(os.path.join(folder_path, name))]
    return file_paths
def plot_one_cadence(cadence, cmap='plasma'):
    """
    Display one "cadence" sample from the SETI dataset as a column of
    stacked subplots, one per snippet, labeled alternately ON/OFF.

    Parameters
    ----------
    cadence : NUMPY ARRAY
        Cadence array; the first axis indexes the snippets.
    cmap : STRING
        Matplotlib colormap name, if a non-default one is desired.

    Returns
    -------
    None.
    """
    snippet_labels = ['ON', 'OFF']
    total = cadence.shape[0]
    plt.figure()
    plt.suptitle('Cadence')
    plt.xlabel('Frequency')
    # One row per snippet; even rows are ON observations, odd rows OFF
    for idx in range(total):
        plt.subplot(total, 1, idx + 1)
        plt.imshow(cadence[idx, :, :].astype(float), cmap=cmap, aspect='auto')
        plt.text(5, 100, snippet_labels[idx % 2], bbox={'facecolor': 'white'})
        plt.xticks([])
    plt.show()
    return
def get_training_augmentations(hyp):
    """
    Build the training-time image augmentation pipeline.

    Parameters
    ----------
    hyp : DICT
        Hyperparameters; must provide 'image_size', 'rotation_degrees',
        'horizontal_flip_prob' and 'vertical_flip_prob'.

    Returns
    -------
    torch.nn.Sequential of torchvision transforms.
    """
    # Per-channel normalization statistics (six values — presumably one
    # per cadence snippet channel; confirm against the data loader)
    channel_means = [0.456, 0.450, 0.443, 0.423, 0.422, 0.423]
    channel_stds = [0.244, 0.243, 0.245, 0.245, 0.249, 0.2409]
    pipeline = torch.nn.Sequential(
        transforms.Resize(hyp['image_size']),
        transforms.RandomRotation(hyp['rotation_degrees']),
        transforms.RandomHorizontalFlip(p=hyp['horizontal_flip_prob']),
        transforms.RandomVerticalFlip(p=hyp['vertical_flip_prob']),
        transforms.Normalize(mean=channel_means, std=channel_stds,
                             inplace=True),
    )
    return pipeline
def get_validation_augmentations(hyp):
    """
    Build the validation-time transform pipeline — a resize only, since
    no randomized augmentation is applied at validation time.

    Parameters
    ----------
    hyp : DICT
        Hyperparameters; must provide 'image_size'.

    Returns
    -------
    torch.nn.Sequential of torchvision transforms.
    """
    resize = transforms.Resize(hyp['image_size'])
    return torch.nn.Sequential(resize)
def get_files_paths_and_labels(data_folder):
    """
    Get all data file paths and their target labels, caching the result
    in <data_folder>/file_paths.json so later calls skip the directory
    walk and CSV parse.

    Parameters
    ----------
    data_folder : Path, STR
        Path to training folder.

    Returns
    -------
    data_file_paths : LIST
        List of paths to data (.npy) files.
    targets : LIST
        List of target values, aligned with data_file_paths.
    """
    # Check for cached file names first
    cache_fn = join(data_folder, "file_paths.json")
    if exists(cache_fn):
        with open(cache_fn, 'r') as fp:
            data = json.load(fp)
        data_file_paths = data['file_paths']
        targets = data['targets']
    else:
        # Grab all relevant file paths.
        # NOTE: must be a real one-element tuple — ('.npy') is just the
        # string '.npy', which only worked by accident of str.endswith.
        data_file_paths = get_paths(data_folder, extensions=('.npy',))
        # Labels CSV lives one directory above the data folder
        labels_path = join(dirname(data_folder), "train_labels.csv")
        labels = pd.read_csv(labels_path, dtype={'id': str, 'target': int})
        labels = dict(labels.values)  # id -> target lookup table
        # The file's basename up to the first '.' is its id key into the
        # labels table; look up each file's target in path order
        targets = [labels[basename(file).split('.')[0]]
                   for file in data_file_paths]
        # Cache for later use
        with open(cache_fn, 'w') as fp:
            json.dump({'file_paths': data_file_paths, 'targets': targets}, fp)
    return data_file_paths, targets