File tree Expand file tree Collapse file tree 2 files changed +4
-1
lines changed Expand file tree Collapse file tree 2 files changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -899,6 +899,7 @@ def __init__(
899899 augment_horizontal_flip = True ,
900900 train_lr = 1e-4 ,
901901 train_num_steps = 100000 ,
902+ max_grad_norm = 1. ,
902903 ema_update_every = 10 ,
903904 ema_decay = 0.995 ,
904905 betas = (0.9 , 0.99 ),
@@ -926,6 +927,7 @@ def __init__(
926927
927928 self .batch_size = train_batch_size
928929 self .gradient_accumulate_every = gradient_accumulate_every
930+ self .max_grad_norm = max_grad_norm
929931
930932 self .train_num_steps = train_num_steps
931933 self .image_size = diffusion_model .image_size
@@ -1013,6 +1015,7 @@ def train(self):
10131015 pbar .set_description (f'loss: { total_loss :.4f} ' )
10141016
10151017 accelerator .wait_for_everyone ()
1018+ accelerator .clip_grad_norm_ (self .model .parameters (), self .max_grad_norm )
10161019
10171020 self .opt .step ()
10181021 self .opt .zero_grad ()
Original file line number Diff line number Diff line change 33setup (
44 name = 'RIN-pytorch' ,
55 packages = find_packages (exclude = []),
6- version = '0.7.6 ' ,
6+ version = '0.7.7 ' ,
77 license = 'MIT' ,
88 description = 'RIN - Recurrent Interface Network - Pytorch' ,
99 author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments