34 changes: 13 additions & 21 deletions config/mpp_avit_ti_config.yaml
@@ -1,6 +1,6 @@
basic_config: &basic_config
# Run settings
log_to_wandb: !!bool True # Use wandb integration
log_to_wandb: !!bool False # Use wandb integration
log_to_screen: !!bool True # Log progress to screen.
save_checkpoint: !!bool True # Save checkpoints
checkpoint_save_interval: 10 # Save every # epochs - also saves "best" according to val loss
@@ -10,26 +10,26 @@ basic_config: &basic_config
enable_amp: !!bool False # Use automatic mixed precision - blows up with low variance fields right now
compile: !!bool False # Compile model - Does not currently work
gradient_checkpointing: !!bool False # Whether to use gradient checkpointing - Slow, but lower memory
exp_dir: '~/MPP' # Output path
exp_dir: '/users/spandit/proj/MPP/results' # Output path
log_interval: 1 # How often to log - Don't think this is actually implemented
pretrained: !!bool False # Whether to load a pretrained model
# wandb settings
project: 'project'
project: 'project'
group: 'debugging'
entity: 'entity'
# Training settings
drop_path: 0.1
batch_size: 1
max_epochs: 500
max_epochs: 200
scheduler_epochs: -1
epoch_size: 2000 # Artificial epoch size
rescale_gradients: !!bool False # Activate hook that scales block gradients to norm 1
optimizer: 'adan' # adam, adan, whatever else i end up adding - adan did better on HP sweep
scheduler: 'cosine' # Only cosine implemented
warmup_steps: 1000 # Warmup when not using DAdapt
learning_rate: -1 # -1 means use DAdapt
weight_decay: 1e-3
n_states: 12 # Number of state variables across the datasets - Can be larger than real number and things will just go unused
weight_decay: 1e-3
n_states: 3 # Number of state variables across the datasets - Can be larger than real number and things will just go unused
state_names: ['Pressure', 'Vx', 'Vy', 'Density', 'Vx', 'Vy', 'Density', 'Pressure'] # Should be sorted
dt: 1 # Striding of data - Not currently implemented > 1
n_steps: 16 # Length of history to include in input
@@ -54,27 +54,19 @@ basic_config: &basic_config
extended_names: !!bool False # Whether to use extended names - not currently implemented
embedding_offset: 0 # Use when adding extra finetuning fields
train_data_paths: [
['~/PDEBench/2D/shallow-water', 'swe', ''],
['~/PDEBench/2D/NS_incom', 'incompNS', ''],
['~/PDEBench/2D/CFD/2D_Train_Rand', compNS, '128'],
['~/PDEBench/2D/CFD/2D_Train_Rand', compNS, '512'],
['~/PDEBench/2D/CFD/2D_Train_Turb', compNS, ''],
['~/PDEBench/2D/diffusion-reaction', 'diffre2d', ''],
['/lustre/scratch5/exempt/artimis/data/pdebench/2D/shallow-water', 'swe', ''],
['/lustre/scratch5/exempt/artimis/data/pdebench/2D/diffusion-reaction', 'diffre2d', ''],
]
valid_data_paths: [
['~/PDEBench/2D/shallow-water', 'swe', ''],
['~/PDEBench/2D/NS_incom', 'incompNS', ''],
['~/PDEBench/2D/CFD/2D_Train_Rand', compNS, '128'],
['~/PDEBench/2D/CFD/2D_Train_Rand', compNS, '512'],
['~/PDEBench/2D/CFD/2D_Train_Turb', compNS, ''],
['~/PDEBench/2D/diffusion-reaction', 'diffre2d', ''],
['/lustre/scratch5/exempt/artimis/data/pdebench/2D/shallow-water', 'swe', ''],
['/lustre/scratch5/exempt/artimis/data/pdebench/2D/diffusion-reaction', 'diffre2d', ''],
]
append_datasets: [] # List of datasets to append to the input/output projections for finetuning


finetune: &finetune
<<: *basic_config
max_epochs: 500
max_epochs: 200
train_val_test: [.8, .1, .1]
accum_grad: 1
pretrained: !!bool True
@@ -90,7 +82,7 @@ finetune: &finetune
freeze_middle: !!bool False # Whether to freeze the middle layers of the model
freeze_processor: !!bool False
append_datasets: [] # List of datasets to append to the input/output projections for finetuning


frozen: &frozen
<<: *finetune
@@ -100,4 +92,4 @@ frozen: &frozen
less_frozen: &less_frozen
<<: *finetune
freeze_middle: !!bool True # Whether to freeze the middle layers of the model
freeze_processor: !!bool True
freeze_processor: !!bool True
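Note on the structure above: `finetune`, `frozen`, and `less_frozen` inherit from `basic_config` via YAML anchors and merge keys (`&basic_config` / `<<: *basic_config`), so each section only overrides what differs. A minimal sketch of how that resolves, assuming the configs are read with PyYAML (whose `safe_load` expands `<<:` merge keys); the loader choice is an assumption, not something shown in this diff:

```python
import yaml

# Load the file changed above and inspect a few merged values.
with open("config/mpp_avit_ti_config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["finetune"]["max_epochs"])           # 200 - overridden in `finetune`
print(cfg["finetune"]["optimizer"])            # 'adan' - inherited from `basic_config`
print(cfg["less_frozen"]["freeze_processor"])  # True - overrides `finetune`'s False
```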
97 changes: 97 additions & 0 deletions config/mpp_lsc240420_avit_ti_config.yaml
@@ -0,0 +1,97 @@
basic_config: &basic_config
# Run settings
log_to_wandb: !!bool False # Use wandb integration
log_to_screen: !!bool True # Log progress to screen.
save_checkpoint: !!bool True # Save checkpoints
checkpoint_save_interval: 10 # Save every # epochs - also saves "best" according to val loss
debug_grad: !!bool True # Compute gradients/step sizes/etc. for debugging
true_time: !!bool False # Debugging setting - sets num workers to zero and activates syncs
num_data_workers: 6 # Generally pulling 8 cpu per process, so using 6 for DL - not sure if best ratio
enable_amp: !!bool False # Use automatic mixed precision - blows up with low variance fields right now
compile: !!bool False # Compile model - Does not currently work
gradient_checkpointing: !!bool False # Whether to use gradient checkpointing - Slow, but lower memory
exp_dir: '/users/spandit/proj/MPP/results' # Output path
log_interval: 1 # How often to log - Don't think this is actually implemented
pretrained: !!bool False # Whether to load a pretrained model
# wandb settings
project: 'project'
group: 'debugging'
entity: 'entity'
# Training settings
drop_path: 0.1
batch_size: 1
max_epochs: 200
scheduler_epochs: -1
epoch_size: 2000 # Artificial epoch size
rescale_gradients: !!bool False # Activate hook that scales block gradients to norm 1
optimizer: 'adan' # adam, adan, whatever else i end up adding - adan did better on HP sweep
scheduler: 'cosine' # Only cosine implemented
warmup_steps: 1000 # Warmup when not using DAdapt
learning_rate: -1 # -1 means use DAdapt
weight_decay: 1e-3
n_states: 9 # Number of state variables across the datasets - Can be larger than real number and things will just go unused
state_names: [
'Uvelocity', 'Wvelocity', 'av_density',
'density_case', 'density_cushion', 'density_maincharge',
'density_outside_air', 'density_striker', 'density_throw'
]
dt: 1 # Striding of data - Not currently implemented > 1
n_steps: 16 # Length of history to include in input
enforce_max_steps: !!bool False # If false and n_steps > dataset steps, use dataset steps. Otherwise, raise Exception.
accum_grad: 5 # Real batch size is accum * batch_size, real steps/"epoch" is epoch_size / accum
# Model settings
model_type: 'avit' # Only option so far
block_type: 'axial' # Which type of block to use - if axial, next two fields must be set to define axial ops
time_type: 'attention' # Conditional on block type
space_type: 'axial_attention' # Conditional on block type
tie_fields: !!bool False # Whether to use 1 embedding per field per data
embed_dim: 192 # Dimension of internal representation - 192/384/768/1024 for Ti/S/B/L
num_heads: 3 # Number of heads for attention - 3/6/12/16 for Ti/S/B/L
processor_blocks: 12 # Number of transformer blocks in the backbone - 12/12/12/24 for Ti/S/B/L
patch_size: [4, 4] # Actually currently hardcoded at 16
bias_type: 'rel' # Options rel, continuous, none
# Data settings
train_val_test: [.8, .1, .1]
augmentation: !!bool False # Augmentation not implemented
use_all_fields: !!bool True # Prepopulate the field metadata dictionary from dictionary in datasets
tie_batches: !!bool False # Force everything in batch to come from one dset
extended_names: !!bool False # Whether to use extended names - not currently implemented
embedding_offset: 0 # Use when adding extra finetuning fields
train_data_paths: [
['/Users/spandit/proj/MPP/lsc240420-2-hdf5/LSC_HDF5', 'lsc240420', ''],
]
valid_data_paths: [
['/Users/spandit/proj/MPP/lsc240420-2-hdf5/LSC_HDF5', 'lsc240420', ''],
]
append_datasets: [] # List of datasets to append to the input/output projections for finetuning


finetune: &finetune
<<: *basic_config
max_epochs: 200
train_val_test: [.8, .1, .1]
accum_grad: 1
pretrained: !!bool True
group: 'debugging'
pretrained_ckpt_path: '/B16-noNS/training_checkpoints/ckpt.tar'
train_data_paths: [
['/PDEBench/2D/CFD/2D_Train_Turb', 'compNS', 'M1.0'],
]
valid_data_paths: [ # These are the same for all configs - uses split according to train_val_test
['/PDEBench/2D/CFD/2D_Train_Turb', 'compNS', 'M1.0'],
]
embedding_offset: 0 # Number of fields in original model - FT fields start after this
freeze_middle: !!bool False # Whether to freeze the middle layers of the model
freeze_processor: !!bool False
append_datasets: [] # List of datasets to append to the input/output projections for finetuning


frozen: &frozen
<<: *finetune
freeze_middle: !!bool True # Whether to freeze the middle layers of the model
freeze_processor: !!bool False

less_frozen: &less_frozen
<<: *finetune
freeze_middle: !!bool True # Whether to freeze the middle layers of the model
freeze_processor: !!bool True
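Since this config is entirely new, a quick pre-flight check can catch the most likely mistakes (a `state_names` list that doesn't match `n_states`, or data directories that don't exist on the machine). A minimal sketch, again assuming PyYAML; the check itself is illustrative, not part of this change:

```python
import os
import yaml

with open("config/mpp_lsc240420_avit_ti_config.yaml") as f:
    cfg = yaml.safe_load(f)["basic_config"]

# n_states (9) may exceed the number of named fields, but must not be smaller.
assert cfg["n_states"] >= len(cfg["state_names"]), "n_states too small for state_names"

# Each entry is a [path, dataset_name, include_string] triple; the path should exist locally.
for path, name, include_string in cfg["train_data_paths"] + cfg["valid_data_paths"]:
    assert os.path.isdir(os.path.expanduser(path)), f"missing data dir for {name}: {path}"
```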
23 changes: 14 additions & 9 deletions data_utils/datasets.py
@@ -1,4 +1,4 @@
"""
"""
Remember to parameterize the file paths eventually
"""
import torch
@@ -8,8 +8,8 @@
from torch.utils.data.distributed import DistributedSampler
import os
try:
from mixed_dset_sampler import MultisetSampler
from hdf5_datasets import *
from .mixed_dset_sampler import MultisetSampler
from .hdf5_datasets import *
except ImportError:
from .mixed_dset_sampler import MultisetSampler
from .hdf5_datasets import *
@@ -19,6 +19,7 @@
broken_paths = []
# IF YOU ADD A NEW DSET MAKE SURE TO UPDATE THIS MAPPING SO MIXED DSET KNOWS HOW TO USE IT
DSET_NAME_TO_OBJECT = {
'lsc240420': LSC240420Dataset,
'swe': SWEDataset,
'incompNS': IncompNSDataset,
'diffre2d': DiffRe2DDataset,
@@ -28,7 +29,7 @@
def get_data_loader(params, paths, distributed, split='train', rank=0, train_offset=0):
# paths, types, include_string = zip(*paths)
dataset = MixedDataset(paths, n_steps=params.n_steps, train_val_test=params.train_val_test, split=split,
tie_fields=params.tie_fields, use_all_fields=params.use_all_fields, enforce_max_steps=params.enforce_max_steps,
tie_fields=params.tie_fields, use_all_fields=params.use_all_fields, enforce_max_steps=params.enforce_max_steps,
train_offset=train_offset)
# dataset = IncompNSDataset(paths[0], n_steps=params.n_steps, train_val_test=params.train_val_test, split=split)
seed = torch.random.seed() if 'train'==split else 0
@@ -37,7 +38,7 @@ def get_data_loader(params, paths, distributed, split='train', rank=0, train_off
else:
base_sampler = RandomSampler
sampler = MultisetSampler(dataset, base_sampler, params.batch_size,
distributed=distributed, max_samples=params.epoch_size,
distributed=distributed, max_samples=params.epoch_size,
rank=rank)
# sampler = DistributedSampler(dataset) if distributed else None
dataloader = DataLoader(dataset,
@@ -48,14 +49,14 @@ def get_data_loader(params, paths, distributed, split='train', rank=0, train_off
drop_last=True,
pin_memory=torch.cuda.is_available())
return dataloader, dataset, sampler


class MixedDataset(Dataset):
def __init__(self, path_list=[], n_steps=1, dt=1, train_val_test=(.8, .1, .1),
split='train', tie_fields=True, use_all_fields=True, extended_names=False,
split='train', tie_fields=True, use_all_fields=True, extended_names=False,
enforce_max_steps=False, train_offset=0):
super().__init__()
# Global dicts used by Mixed DSET.
# Global dicts used by Mixed DSET.
self.train_offset = train_offset
self.path_list, self.type_list, self.include_string = zip(*path_list)
self.tie_fields = tie_fields
@@ -80,6 +81,10 @@ def __init__(self, path_list=[], n_steps=1, dt=1, train_val_test=(.8, .1, .1),

self.subset_dict = self._build_subset_dict()

print("\n[DEBUG] Final subset_dict:")
for k, v in self.subset_dict.items():
print(f" {k}: {v}")

def get_state_names(self):
name_list = []
if self.use_all_fields:
@@ -131,6 +136,6 @@ def __getitem__(self, index):
print('FAILED AT ', file_idx, local_idx, index,int(os.environ.get("RANK", 0)))
thisvariabledoesntexist
return x, file_idx, torch.tensor(self.subset_dict[self.sub_dsets[file_idx].get_name()]), bcs, y

def __len__(self):
return sum([len(dset) for dset in self.sub_dsets])
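Tying the new config to the loader changes above: the `'lsc240420'` tag in `train_data_paths` is resolved through `DSET_NAME_TO_OBJECT` (extended here with `LSC240420Dataset`), and `MixedDataset.__init__` unpacks each `[path, dataset_name, include_string]` triple via `zip(*path_list)`. A minimal sketch of that lookup, assuming it is run from the repo root so `data_utils` is importable; the printed values simply echo the config entries:

```python
from data_utils.datasets import DSET_NAME_TO_OBJECT

# One [path, dataset_name, include_string] triple, copied from the new config.
train_data_paths = [
    ['/Users/spandit/proj/MPP/lsc240420-2-hdf5/LSC_HDF5', 'lsc240420', ''],
]

# Mirrors MixedDataset.__init__: self.path_list, self.type_list, self.include_string = zip(*path_list)
paths, types, include_strings = zip(*train_data_paths)

dataset_cls = DSET_NAME_TO_OBJECT[types[0]]  # -> LSC240420Dataset, registered by this change
print(dataset_cls.__name__, paths[0], repr(include_strings[0]))
```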