|
7 | 7 | """ |
8 | 8 |
|
9 | 9 | import logging |
10 | | -import os |
11 | 10 | import random |
12 | 11 | from contextlib import contextmanager |
13 | 12 | from functools import partial |
14 | 13 |
|
15 | 14 | import numpy as np |
16 | | -from einops import rearrange, repeat |
17 | | -from tqdm import tqdm |
18 | | - |
19 | | -mainlogger = logging.getLogger("mainlogger") |
20 | | - |
21 | 15 | import peft |
22 | 16 | import pytorch_lightning as pl |
23 | 17 | import torch |
24 | | -import torch.nn as nn |
| 18 | +from einops import rearrange, repeat |
25 | 19 | from pytorch_lightning.utilities import rank_zero_only |
26 | 20 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR |
27 | 21 | from torchvision.utils import make_grid |
| 22 | +from tqdm import tqdm |
28 | 23 |
|
29 | 24 | from videotuna.base.ddim import DDIMSampler |
30 | | -from videotuna.base.distributions import DiagonalGaussianDistribution, normal_kl |
| 25 | +from videotuna.base.distributions import DiagonalGaussianDistribution |
31 | 26 | from videotuna.base.ema import LitEma |
32 | | -from videotuna.base.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr |
33 | 27 |
|
34 | 28 | # import rlhf utils |
35 | 29 | from videotuna.lvdm.models.rlhf_utils.batch_ddim import batch_ddim_sampling |
36 | 30 | from videotuna.lvdm.models.rlhf_utils.reward_fn import aesthetic_loss_fn |
37 | 31 | from videotuna.lvdm.modules.encoders.ip_resampler import ImageProjModel, Resampler |
38 | | -from videotuna.lvdm.modules.utils import ( |
39 | | - default, |
40 | | - disabled_train, |
41 | | - exists, |
42 | | - extract_into_tensor, |
43 | | - noise_like, |
44 | | -) |
| 32 | +from videotuna.lvdm.modules.utils import default, disabled_train, extract_into_tensor |
45 | 33 | from videotuna.utils.common_utils import instantiate_from_config |
46 | 34 |
|
47 | 35 | __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} |
48 | 36 |
|
49 | 37 |
|
| 38 | +mainlogger = logging.getLogger("mainlogger") |
| 39 | + |
| 40 | + |
50 | 41 | class DDPMFlow(pl.LightningModule): |
51 | 42 | # classic DDPM with Gaussian diffusion, in image space |
52 | 43 | def __init__( |
@@ -430,7 +421,7 @@ def load_lora_from_ckpt(self, model, path): |
430 | 421 | f"Parameter {key} from lora_state_dict was not copied to the model." |
431 | 422 | ) |
432 | 423 | # print(f"Parameter {key} from lora_state_dict was not copied to the model.") |
433 | | - print(f"All Parameters was copied successfully.") |
| 424 | + print("All Parameters was copied successfully.") |
434 | 425 |
|
435 | 426 | def inject_lora(self): |
436 | 427 | """inject lora into the denoising module. |
@@ -519,7 +510,7 @@ def __init__( |
519 | 510 |
|
520 | 511 | try: |
521 | 512 | self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 |
522 | | - except: |
| 513 | + except Exception: |
523 | 514 | self.num_downs = 0 |
524 | 515 | if not scale_by_std: |
525 | 516 | self.scale_factor = scale_factor |
@@ -1586,7 +1577,7 @@ def configure_optimizers(self): |
1586 | 1577 |
|
1587 | 1578 | if self.cond_stage_trainable: |
1588 | 1579 | params_cond_stage = [ |
1589 | | - p for p in self.cond_stage_model.parameters() if p.requires_grad == True |
| 1580 | + p for p in self.cond_stage_model.parameters() if p.requires_grad is True |
1590 | 1581 | ] |
1591 | 1582 | mainlogger.info( |
1592 | 1583 | f"@Training [{len(params_cond_stage)}] Paramters for Cond_stage_model." |
|
0 commit comments