Skip to content

Commit c8ec12f

Browse files
committed
pre-commit check
1 parent d92ebbc commit c8ec12f

File tree

4 files changed

+17
-50
lines changed

4 files changed

+17
-50
lines changed

fastvideo/configs/configs.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,6 @@
99
logger = init_logger(__name__)
1010

1111

12-
class DatasetType(str, Enum):
13-
"""
14-
Enumeration for different dataset types.
15-
"""
16-
HF = "hf"
17-
MERGED = "merged"
18-
19-
@classmethod
20-
def from_string(cls, value: str) -> "DatasetType":
21-
"""Convert string to DatasetType enum."""
22-
try:
23-
return cls(value.lower())
24-
except ValueError:
25-
raise ValueError(
26-
f"Invalid dataset type: {value}. Must be one of: {', '.join([m.value for m in cls])}"
27-
) from None
28-
29-
@classmethod
30-
def choices(cls) -> list[str]:
31-
"""Get all available choices as strings for argparse."""
32-
return [dataset_type.value for dataset_type in cls]
33-
34-
3512
class DatasetType(str, Enum):
3613
"""
3714
Enumeration for different dataset types.

fastvideo/configs/models/encoders/clip.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ class CLIPTextConfig(TextEncoderConfig):
8080
"truncation": True,
8181
"max_length": 77,
8282
"return_tensors": "pt"
83-
}
84-
)
83+
})
8584
num_hidden_layers_override: int | None = None
8685
require_post_norm: bool | None = None
8786
prefix: str = "clip"

fastvideo/pipelines/preprocess/hunyuan/hunyuan_preprocess_pipelines.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,8 @@
99

1010
class PreprocessPipelineI2V(ComposedPipelineBase):
1111
_required_config_modules = [
12-
"image_encoder",
13-
"image_processor",
14-
"text_encoder",
15-
"tokenizer",
16-
"text_encoder_2",
17-
"tokenizer_2",
18-
"vae"
12+
"image_encoder", "image_processor", "text_encoder", "tokenizer",
13+
"text_encoder_2", "tokenizer_2", "vae"
1914
]
2015

2116
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
@@ -58,12 +53,9 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
5853

5954
class PreprocessPipelineT2V(ComposedPipelineBase):
6055
_required_config_modules = [
61-
"text_encoder",
62-
"tokenizer",
63-
"text_encoder_2",
64-
"tokenizer_2",
65-
"vae"
56+
"text_encoder", "tokenizer", "text_encoder_2", "tokenizer_2", "vae"
6657
]
58+
6759
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
6860
assert fastvideo_args.preprocess_config is not None
6961
self.add_stage(stage_name="text_transform_stage",
@@ -73,27 +65,27 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
7365
seed=fastvideo_args.preprocess_config.seed,
7466
))
7567
# llama_tokenizer_kwargs = {
76-
# "padding": "max_length",
77-
# "truncation": True,
78-
# "max_length": 256,
68+
# "padding": "max_length",
69+
# "truncation": True,
70+
# "max_length": 256,
7971
# "return_tensors": "pt"
8072
# }
8173
# clip_tokenizer_kwargs = {
82-
# "padding": "max_length",
83-
# "truncation": True,
84-
# "max_length": 77,
74+
# "padding": "max_length",
75+
# "truncation": True,
76+
# "max_length": 77,
8577
# "return_tensors": "pt"
8678
# }
8779
# if len(fastvideo_args.pipeline_config.text_encoder_configs) >= 2:
8880
# fastvideo_args.pipeline_config.text_encoder_configs[0].tokenizer_kwargs = llama_tokenizer_kwargs
8981
# fastvideo_args.pipeline_config.text_encoder_configs[1].tokenizer_kwargs = clip_tokenizer_kwargs
9082
text_encoders = [
91-
self.get_module("text_encoder"),
92-
self.get_module("text_encoder_2")
83+
self.get_module("text_encoder"),
84+
self.get_module("text_encoder_2")
9385
]
9486
tokenizers = [
95-
self.get_module("tokenizer"),
96-
self.get_module("tokenizer_2")
87+
self.get_module("tokenizer"),
88+
self.get_module("tokenizer_2")
9789
]
9890

9991
self.add_stage(stage_name="prompt_encoding_stage",

fastvideo/workflow/preprocess/components.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def __call__(self, batch: dict[str, Any]) -> bool:
7878

7979
def _validate_data_type(self, batch: dict[str, Any]) -> bool:
8080
"""Validate basic validity of data items"""
81-
return not (batch["caption"] is None or batch["caption"] == ""
82-
or "fps" not in batch or batch["fps"] is None or batch["fps"] <= 0
81+
return not (batch["caption"] is None or batch["caption"] == "" or "fps"
82+
not in batch or batch["fps"] is None or batch["fps"] <= 0
8383
or batch["num_frames"] is None or batch["num_frames"] <= 0)
8484

8585
def _validate_resolution(self, batch: dict[str, Any]) -> bool:
@@ -399,7 +399,6 @@ def _default_file_writer_fn(self, args_tuple: tuple) -> int:
399399
return written_count
400400

401401

402-
403402
def build_dataset(preprocess_config: PreprocessConfig, split: str) -> Dataset:
404403
if preprocess_config.dataset_type == DatasetType.HF:
405404
dataset = load_dataset(preprocess_config.dataset_path, split=split)

0 commit comments

Comments
 (0)