-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathconfig.py
More file actions
193 lines (163 loc) · 5.63 KB
/
config.py
File metadata and controls
193 lines (163 loc) · 5.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# Configuration file for (base training / adversarial training / adversarial attack)
import argparse
from distutils import util
# Basic Configuration


def str2bool(value):
    """Convert a command-line string such as ``"yes"`` or ``"0"`` to a ``bool``.

    Drop-in replacement for ``distutils.util.strtobool`` (``distutils`` was
    deprecated by PEP 632 and removed in Python 3.12). The same spellings are
    accepted, case-insensitively:

    * true:  ``y``, ``yes``, ``t``, ``true``, ``on``, ``1``
    * false: ``n``, ``no``, ``f``, ``false``, ``off``, ``0``

    A value that is already a ``bool`` is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other string. argparse turns this
            into a readable usage error instead of "invalid <lambda> value".
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "boolean value expected (true/false, yes/no, 1/0), got {!r}".format(value)
    )


# NOTE: the boolean flags below use ``str2bool``, so this module no longer
# needs ``distutils.util``.
parser = argparse.ArgumentParser(description="Base Training")
##########################PATH MERGE##########################
# TODO: merge all the paths as one
# path
parser.add_argument(
    "--load_path",
    type=str,
    default="/load/parent/folder/path/",
    help="Path of load parent folder.",
)
parser.add_argument(
    "--save_path",
    type=str,
    default="/save/parent/folder/path/",
    help="Path of save parent folder.",
)
parser.add_argument(
    "--save_atk",
    default=False,
    type=str2bool,
    help="save adversarial examples? (True / False)",
)
# Example --load_path / --save_path values for each --mode:
## split path
# ====> --load_path '/data/inputs_full/'
# ====> --save_path '/data/split/' => use save_path + [train.txt | valid.txt]
## trainer
# basetrain
# ====> --load_path '/data/split/'
# save loader
# ====> --save_path '/data/drum/dataset/' => use save_path + [train/ | valid/]
##generator
# ====> --load_path '/data/MAESTRO/maestro-v2.0.0/'
# ====> --save_path '/data/which_dir/'
## attacker
# ===> --load_path '/data/drum/bestmodel/' => use --load_path + [model/ | dataset/]
# ===> --save_path '/data/attacks/'
##converter
# ====> --load_path "/data/attacks/08-25-00-00/ep0.6/"
# parser.add_argument(
#     "--to_convert_path",
#     default="/data/attacks/",
#     type=str,
#     help="Path for 'only MIDIs' to convert. Path CAN contain any folder !!",
# )
## attacked input
# ====> --save_path '/data/attacks/vel_deepfool/' => use save_path + [train/ | valid/]
# parser.add_argument(
#     "--attacked_train_input_path",
#     type=str,
#     default="/data/attacks/vel_deepfool/train/",
#     help="Attacked Train input directory.",
# )
# parser.add_argument(
#     "--attacked_valid_input_path",
#     type=str,
#     default="/data/attacks/vel_deepfool/valid/",
#     help="Attacked Valid input directory.",
# )
##########################PATH MERGE##########################
parser.add_argument(
    "--mode",
    default="foo",  # deliberately not a valid mode: the caller must pass --mode
    type=str,
    help="Mode (basetrain / advtrain / attack / generate / convert / split)",
)
parser.add_argument(
    "--composers", default=13, type=int, help="The number of composers.",
)
parser.add_argument(
    "--model_name",
    type=str,
    default="resnet50",
    help="Prefix of model name (resnet18 / resnet34 / resnet50 / resnet101 / resnet152 / convnet)",
)
parser.add_argument(
    "--optim",
    type=str,
    default="SGD",
    help="Optimizer [Adadelta, Adagrad, Adam, AdamW, SparseAdam, Adamax, ASGD, RMSprop, Rprop, SGD, Nesterov]",
)
parser.add_argument(
    "--transform", type=str, default=None, help="Transform mode [Transpose / Tempo]",
)
parser.add_argument(
    "--epochs",
    default=100,
    type=int,
    help="Total number of epochs to run. Not actual epoch.",
)
parser.add_argument(
    "--trn_seg", default=90, type=int, help="segment number for train."
)
parser.add_argument(
    "--val_seg", default=90, type=int, help="segment number for valid."
)
parser.add_argument(
    "--train_batch", default=40, type=int, help="Batch size for training"
)
# parser.add_argument("--valid_batch", default=40, type=int, help="Batch size for valid.")
parser.add_argument("--gpu", default="0,1,2,3", type=str, help="GPU id to use.")
parser.add_argument("--lr", default=0.01, type=float, help="Model learning rate.")
parser.add_argument(
    "--save_trn", type=str2bool, default=True, help="Save both model & loader?"
)
parser.add_argument(
    "--onset", type=str2bool, default=True, help="use Onset channel"
)
##attack parameters
# 1. basic configurations
# parser.add_argument("--attack_mode", default="base", type=str, help="Attack Mode (base / trained)") # full name of .pt will tell anyway
parser.add_argument("--target_label", default=None, type=int, help="label for targeted attack")
parser.add_argument(
    "--attack_type", default="fgsm", type=str, help="attack (fgsm / deepfool / random)"
)
# 2. data related
parser.add_argument(
    "--orig",
    default=True,
    type=str2bool,
    help="attack on original dataset? (default: True)",
)
# 3. specific attack related
# fgsm
parser.add_argument(
    "--epsilons",
    default="0.0",
    type=str,
    # Kept as a raw string; the consumer is expected to split it on ",".
    help="list of epsilons 'ep0, ep1, ep2..' seperated by ,",
)
parser.add_argument(
    "--variable",
    type=float,
    default=None,
    help="Depends on attack type: (randomness ctrl / max # of notes in column / uniform velocity value)"
)
# deepfool
parser.add_argument(
    "--max_iter", default=10, type=int, help="max iterations for deepfool attack",
)
# NOTE(review): --overshoot is parsed as int, so fractional values such as
# 0.02 are rejected; confirm with the consumer before switching to float.
parser.add_argument(
    "--overshoot", default=5, type=int, help="overshoot for deepfool attack"
)
parser.add_argument("--plot", default=False, type=str2bool, help="draw plot?")
parser.add_argument(
    "--confusion", default=True, type=str2bool, help="draw confusion matrix?")
##spliter
# uses --load_path / --save_path (see the "split path" notes above)
parser.add_argument(
    "--train_percentage", default=0.7, type=float, help="Train data percentage (0 ~ 1)",
)
parser.add_argument(
    "--omit", default=None, type=str, help="List of omitted composers' indices.",
)
parser.add_argument(
    "--age", default=False, type=str2bool, help="Classification of Age? (True / False)",
)
def get_config(argv=None):
    """Parse command-line options with the module-level ``parser``.

    Unrecognised options do not raise; they are returned separately.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default, matching the previous behaviour), argparse reads
            ``sys.argv[1:]``. Passing an explicit list lets the configuration
            be built programmatically, e.g. from tests or notebooks.

    Returns:
        A ``(config, unparsed)`` tuple: the ``argparse.Namespace`` holding the
        recognised options, and the list of argument strings the parser did
        not recognise.
    """
    config, unparsed = parser.parse_known_args(argv)
    return config, unparsed