
Commit 66796b5

Author: Huy Vu2
fix Ruff, lint
1 parent e936907 · commit 66796b5

File tree

7 files changed, +32 -16 lines


dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py

14 additions, 0 deletions

@@ -1,3 +1,17 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import random
 from abc import ABC, abstractmethod
 from typing import List

dfm/src/megatron/data/wan/wan_energon_datamodule.py

2 additions, 2 deletions

@@ -16,10 +16,10 @@
 
 from dataclasses import dataclass
 
-from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider
+from megatron.bridge.data.utils import DatasetBuildContext
 from torch import int_repr
 
-from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModuleConfig, DiffusionDataModule
+from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModule, DiffusionDataModuleConfig
 from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder
 
 
dfm/src/megatron/data/wan/wan_sample.py

1 addition, 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
+
 from dfm.src.megatron.data.common.diffusion_sample import DiffusionSample
 
 
dfm/src/megatron/data/wan/wan_taskencoder.py

7 additions, 7 deletions

@@ -14,17 +14,17 @@
 
 # pylint: disable=C0115,C0116,C0301
 
-from torch._tensor import Tensor
-import torch
-import torch.nn.functional as F
-from megatron.energon.task_encoder.base import stateless
-from megatron.core import parallel_state
 from typing import List
-from megatron.energon import SkipSample
+
 from dfm.src.megatron.data.common.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking
-from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys
 from dfm.src.megatron.data.wan.wan_sample import WanSample
 from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify
+from megatron.core import parallel_state
+from megatron.energon import SkipSample
+from megatron.energon.task_encoder.base import stateless
+from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys
+import torch
+import torch.nn.functional as F
 
 
 def cook(sample: dict) -> dict:
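
The reordering above is what Ruff's isort-derived import-sorting rule (I001) produces: standard-library imports in their own block, then all other packages alphabetized by module path; it also drops the unused `from torch._tensor import Tensor` (rule F401). A minimal sketch of the resulting shape — the exact grouping and spacing depend on the repo's Ruff configuration, which is not shown in this commit:

    # Standard library first, alphabetized
    from abc import ABC, abstractmethod
    from typing import List

    # Then other packages, alphabetized by module path
    import numpy as np
    import torch
    import torch.nn.functional as F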

dfm/src/megatron/model/wan/rope_utils.py

1 addition, 1 deletion

@@ -69,7 +69,7 @@ def forward(self, n_head, dim_head, cu_seqlens_q_padded, grid_sizes, device):
 
         # Pad freqs_real_i to (padded_seq_len, 1, 1, dim_head) with 0s
         for i, freqs_real_i in enumerate(freqs_real):
-            seq_len_q_padded = cu_seqlens_q_padded[i+1] - cu_seqlens_q_padded[i]
+            seq_len_q_padded = cu_seqlens_q_padded[i + 1] - cu_seqlens_q_padded[i]
             if freqs_real_i.shape[0] < seq_len_q_padded:
                 pad_shape = (seq_len_q_padded - freqs_real_i.shape[0], 1, 1, dim_head)
                 freqs_real_i = torch.cat(
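
The change itself is only Ruff's E225 spacing fix (spaces around binary operators), but the line it touches is the standard cu_seqlens idiom: sample i occupies the span between consecutive entries of the cumulative-length tensor, so their difference is that sample's padded length. A self-contained sketch of the idiom, with made-up lengths and a stand-in frequency tensor:

    import torch
    import torch.nn.functional as F

    # Sample i spans [cu_seqlens[i], cu_seqlens[i + 1]) in the packed
    # sequence, so its padded length is the consecutive difference.
    cu_seqlens_q_padded = torch.tensor([0, 128, 384, 512])  # made-up values
    dim_head = 64

    for i in range(len(cu_seqlens_q_padded) - 1):
        seq_len_q_padded = int(cu_seqlens_q_padded[i + 1] - cu_seqlens_q_padded[i])
        freqs_real_i = torch.randn(100, 1, 1, dim_head)  # stand-in tensor
        if freqs_real_i.shape[0] < seq_len_q_padded:
            # Zero-pad the sequence dimension up to the padded length
            # (rope_utils.py builds the pad tensor explicitly and torch.cat's it).
            pad_rows = seq_len_q_padded - freqs_real_i.shape[0]
            freqs_real_i = F.pad(freqs_real_i, (0, 0, 0, 0, 0, 0, 0, pad_rows))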

dfm/src/megatron/recipes/wan/wan.py

1 addition, 1 deletion

@@ -180,7 +180,7 @@ def pretrain_config(
         global_batch_size=global_batch_size,
         num_workers=10,
         task_encoder_seq_length=None,
-        packing_buffer_size=131072, # 131,072 = 2^17 tokens, each 5 secs of 832*480 is about 45k tokens
+        packing_buffer_size=131072,  # 131,072 = 2^17 tokens, each 5 secs of 832*480 is about 45k tokens
     )
 
     # Config Container
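
The ~45k figure in that comment is a latent-token estimate whose exact value depends on the source frame rate, the VAE's compression factors, and the patch size, none of which appear in this diff. A back-of-the-envelope sketch under assumed Wan-like factors (16 fps, 4x temporal and 8x spatial VAE compression, 2x2 patches — all four are assumptions) lands in the same tens-of-thousands range, comfortably under the 2^17-token packing buffer:

    # All factors are assumptions for illustration, not values from this repo.
    fps = 16                  # assumed source frame rate
    t_comp, s_comp = 4, 8     # assumed VAE temporal / spatial compression
    patch = 2                 # assumed 2x2 spatial patchification

    frames = 5 * fps + 1                               # 81 frames in 5 seconds
    latent_t = (frames - 1) // t_comp + 1              # 21 latent frames
    latent_h, latent_w = 480 // s_comp, 832 // s_comp  # 60 x 104 latent pixels
    tokens = latent_t * (latent_h // patch) * (latent_w // patch)
    print(tokens)  # 32760 -- same order of magnitude as the comment's ~45k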

examples/megatron/recipes/wan/prepare_energon_dataset_wan.py

6 additions, 5 deletions

@@ -18,12 +18,11 @@
 from typing import Dict, List, Optional, Tuple
 
 import cv2
+from diffusers import AutoencoderKLWan
 import numpy as np
 import torch
-import webdataset as wds
-
-from diffusers import AutoencoderKLWan
 from transformers import AutoTokenizer, UMT5EncoderModel
+import webdataset as wds
 
 
 def _map_interpolation(resize_mode: str) -> int:

@@ -412,7 +411,7 @@ def main():
     for index, meta in enumerate(metadata_list):
         video_name = meta["file_name"]
         start_frame = int(meta["start_frame"])  # inclusive
-        end_frame = int(meta["end_frame"]) # inclusive
+        end_frame = int(meta["end_frame"])  # inclusive
         caption_text = meta.get("vila_caption", "")
 
         video_path = str(video_folder / video_name)

@@ -431,7 +430,9 @@
 
         # Encode text and video with HF models exactly like automodel
         text_embed = _encode_text(tokenizer, text_encoder, args.device, caption_text)
-        latents = _encode_video_latents(vae, args.device, video_tensor, deterministic_latents=not args.stochastic)
+        latents = _encode_video_latents(
+            vae, args.device, video_tensor, deterministic_latents=not args.stochastic
+        )
 
         # Move to CPU without changing dtype; keep exact values to match automodel outputs
         text_embed_cpu = text_embed.detach().to(device="cpu")
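
The deterministic_latents flag presumably chooses between taking the mode of the VAE's latent distribution and drawing a random sample from it; that is the standard diffusers encode() pattern, though _encode_video_latents itself is not shown in this diff, so treat the following as an assumption about its body. A minimal sketch (checkpoint name and tensor shape are illustrative):

    import torch
    from diffusers import AutoencoderKLWan

    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae"
    )

    # (batch, channels, frames, height, width); frame count of the form 4k + 1
    video_tensor = torch.randn(1, 3, 9, 64, 64)
    deterministic = True  # mirrors deterministic_latents=not args.stochastic
    with torch.no_grad():
        dist = vae.encode(video_tensor).latent_dist
        latents = dist.mode() if deterministic else dist.sample()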
