Description
We recently had sizing issues with several models, and these bugs were only discovered far downstream of training.
To prevent this in the future, we'd like to add shape checking that uses config fields such as sample_size and latent_size to compute the expected shapes of input and latent tensors, so they can be verified during training. For example, if an image VAE has sample_size [360, 640], we need to verify that the input actually has that shape, since conv layers will run on other sizes without throwing errors. We also need to verify (and this is the most important part) that the latents actually have shape latent_size.
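A minimal sketch of such a check, under stated assumptions: tensors are laid out as (batch, channels, *spatial), sample_size and latent_size are either a bare int (broadcast to every spatial dim) or a list such as [360, 640], and the batch dim is not checked. The helper names check_shape and checked_encode and the config keys channels and latent_channels are hypothetical; the real owl_vaes configs may name or structure these fields differently. It only reads .shape, so it works with PyTorch tensors (torch.Size is a tuple subclass).

```python
from typing import Callable, Mapping, Sequence, Tuple, Union

Size = Union[int, Sequence[int]]


def _expected_spatial(size: Size, ndim: int) -> Tuple[int, ...]:
    # Assumption: a bare int applies to every spatial dim (1-D audio or square
    # images); a list such as [360, 640] is used exactly as given.
    if isinstance(size, int):
        return (size,) * ndim
    return tuple(int(s) for s in size)


def check_shape(name: str, shape: Sequence[int], channels: int, size: Size) -> None:
    """Raise ValueError unless `shape` is (batch, channels, *size).

    The batch dim is skipped because it can vary between steps.
    """
    shape = tuple(shape)
    if len(shape) < 3:
        raise ValueError(f"{name}: expected (B, C, ...) but got {shape}")
    expected = (channels,) + _expected_spatial(size, len(shape) - 2)
    if shape[1:] != expected:
        dims = ", ".join(map(str, expected))
        raise ValueError(f"{name}: expected (B, {dims}), got {shape}")


def checked_encode(encode: Callable, x, cfg: Mapping):
    """Hypothetical wrapper: check the input, run the encoder, then check the latent."""
    # Conv layers accept any spatial size, so a wrong input would otherwise pass silently.
    check_shape("input", x.shape, cfg["channels"], cfg["sample_size"])
    z = encode(x)
    # The most important check: the latent must match latent_size.
    check_shape("latent", z.shape, cfg["latent_channels"], cfg["latent_size"])
    return z
```

In a training loop this would wrap the encoder call on every step (or only on the first few steps if the overhead matters), so a mismatched sample_size or latent_size fails immediately instead of surfacing downstream.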
This change should be added to all models, but for now the most important ones are owl_vaes/models/oobleck.py and owl_vaes/models/dcae.py.
See the example configs for each: configs/audio_ae_2.yml and configs/cod_128x_depth.yml.