# default_pretrain.yaml
# Default configuration for MAE pretraining (run name: pretrain_mae) on flat fMRI data.
# Run identity. `name` is the base name of the run's output directory and is also
# used for wandb; `notes` is a free-text description sent to wandb as run notes.
name: pretrain_mae
notes: null
# Root directory under which run output directories are created.
output_dir: checkpoints
# Spatial input geometry. img_size is the standard (H, W) image size of the flat
# fMRI datasets; it is used when building the model (patch embedding etc.).
img_size: [224, 560]
in_chans: 1  # single input channel
patch_size: 16
# Temporal input geometry: frames per clip and frames per temporal patch.
# With both set to 16, each clip spans exactly one temporal patch (T // t = 1).
num_frames: 16
t_patch_size: 16
# Mask ratio: fraction of patches that are masked out (not visible).
# The number of visible (kept) patches is
#   len_keep = (T // t) * (H // h) * (W // w) * (1 - mask_ratio)
# with (T, H, W) the clip size and (t, h, w) the patch size.
#
# For image size (224, 560) this gives the following visible patches per frame:
#   0.75  -> 122
#   0.804 -> 96
#   0.869 -> 64
#   0.9   -> 48
#   0.96  -> 19
#
# For comparison, the original MAE was trained with 0.75 on 224 x 224 (49 patches),
# while MAE-st was trained with 0.9 (19 patches).
mask_ratio: 0.9
# Decoder mask ratio for VideoMAEv2-style sparse decoding.
# Like mask_ratio, the number of predicted patches is computed relative to the full
# image, so sparse prediction only happens when pred_mask_ratio > (1 - mask_ratio).
pred_mask_ratio: null
# Constraint on the visible mask:
#   null: unconstrained uniform random masking (the default)
#   tube: patch tube masking, as in VideoMAE
masking: null
masking_kwargs: {}
# Model
model: mae_vit_base
# Keyword arguments for the model constructor. These keys must be nested under
# model_kwargs; at top level they would leave model_kwargs null.
model_kwargs:
  # decoding mode:
  # - attn: standard self-attention decoding as in the original MAE(-st)
  # - cross: cross-attention decoding as in CrossMAE
  # - crossreg: cross-register decoding inspired by MAETok
  decoding: attn
  # position embedding:
  # - abs: absolute learned position embedding
  # - sep: separable learned position embedding for time and space, from MAE-st
  # - sincos: fixed sin-cos position encoding
  pos_embed: sep
  # target normalization:
  # - null: presumably no target normalization (TODO confirm)
  # - global: normalize by global mean/std
  # - frame: normalize for each temporal "frame" of the patch grid
  # - patch: normalize for each patch (same as the MAE pix norm loss)
  target_norm: null
  # decoder prediction frame stride, from MAE-st; e.g. stride = 2 means predict
  # every other frame starting from frame 0.
  t_pred_stride: 2
  # don't add position encoding to the decoder input embeddings.
  # the original MAE does position-encode them, but it might not be needed since
  # they are already position encoded by the encoder.
  # note this only applies when decoding = attn; for cross and crossreg the
  # embeddings are never position encoded.
  no_decode_pos: false
  # scale input patches by 1 / observed rate (like dropout)
  mask_drop_scale: false
  # don't predict pixels along the edges of the visible mask.
  # not from MAE; added to try out because of interpolation observed across
  # patch edges.
  pred_edge_pad: 0
  class_token: true
  # number of register tokens
  reg_tokens: 0
  # don't add the position embedding to the cls and reg tokens
  no_embed_class: false
  decoder_depth: 4
  drop_path_rate: 0.0
# Datasets, keyed by name. Each entry's settings must be nested under its key;
# flattened, the repeated keys (type, root, shuffle, ...) would collide at top level.
datasets:
  hcp-train:
    type: flat-wds
    # hcp-flat contains 2000 shards, split into 20 batches of unrelated subjects.
    # train on the first 18 batches (shards 0000-1799).
    url: datasets/hcp-flat/hcp-flat_{0000..1799}.tar
    clipping: random
    clipping_kwargs:
      # sample more clips from each series, to saturate the data loader.
      oversample: 4.0
    shuffle: true
    buffer_size: 1000
    samples_per_epoch: 200000
  hcp-train-subset:
    # clips from shards {0000..0009} of hcp-flat
    type: flat-clips
    root: datasets/flat-clips/hcp-train-clips-16t
    shuffle: false
  hcp-val:
    # clips from shards {1800..1809} of hcp-flat (outside the training shards)
    type: flat-clips
    root: datasets/flat-clips/hcp-val-clips-16t
    shuffle: false
  nsd-val:
    # clips from shards {0000..0009} of nsd-flat
    type: flat-clips
    root: datasets/flat-clips/nsd-subj01-clips-16t
    shuffle: false
# which of the datasets above (by key) to use for training and evaluation
train_dataset: hcp-train
eval_datasets:
  - hcp-train-subset
  - hcp-val
  - nsd-val
# Data transform
# clip extreme values to [-vmax, vmax]. note that during flat dataset generation,
# each vertex time series is z-scored for each run, so values should be roughly
# standard normal.
clip_vmax: 3.0
# normalize each video clip.
# - null: no normalization
# - global: the clip is globally normalized to zero mean and unit variance.
# - frame: each temporal frame is independently normalized.
# fMRI data have a lot of slow global variation (i.e. the "global signal").
# z-scoring each clip is one way to get the model to focus on more interesting
# local activity variation.
normalize: frame
# cropping options
# random resized crop, for training only.
random_crop: false
# conservative default settings. assuming scale is an area fraction (as in
# torchvision RandomResizedCrop), scale 0.9 keeps ~95% of each side (sqrt(0.9)),
# i.e. roughly one patch lost in height and two in width.
# the ratio keeps the original aspect ratio (560 / 224 = 2.5).
crop_kwargs:
  scale: [0.9, 1.0]
  ratio: [2.5, 2.5]
  interpolation: 3  # PIL interp constant, 2=BILINEAR, 3=BICUBIC
# Data loader
num_workers: 16
# Optimization
epochs: 100
# batch size per gpu, so
#   total_batch_size = batch_size * accum_iter * num_gpus_per_node * num_nodes
batch_size: 32
accum_iter: 1
# absolute_lr = base_lr * total_batch_size / 256
# written as 1.0e-3 rather than 1e-3: YAML 1.1 loaders such as PyYAML only
# resolve exponent floats that contain a dot, and would read 1e-3 as a string.
base_lr: 1.0e-3
# lower lr bound for cyclic schedulers that hit 0
min_lr: 0.0
warmup_epochs: 5
weight_decay: 0.05
betas: [0.9, 0.95]
clip_grad: 1.0
amp: true
amp_dtype: float16
# Checkpoint to initialize the model from (null: none).
ckpt: null
# Whether to resume training (true) or restart it (false).
resume: true
# When true, relaunching a failed run with the same output directory resumes it
# automatically.
auto_resume: true
# When resuming, the start epoch comes from the checkpoint; this setting exists in
# case it needs to be overridden manually.
start_epoch: 0
# The last checkpoint is always saved; max_checkpoints sets how many extra ones are
# kept. The default of 0 keeps only the last checkpoint.
max_checkpoints: 0
checkpoint_period: null
# Runtime
device: cuda
# send data to cuda ahead of time, asynchronously
presend_cuda: false
seed: 7338
debug: false
# Weights & Biases logging
wandb: true
wandb_entity: null
wandb_project: fMRI-foundation-model