Skip to content

Commit bf4de1f

Browse files
[chore] Upgrade min Python version from 3.8 to 3.10 (#597)
1 parent 66fdcc8 commit bf4de1f

File tree

115 files changed

+1121
-1163
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

115 files changed

+1121
-1163
lines changed

comfyui/video_generator/load_image.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import hashlib
22
import os
3-
from typing import List
43

54
import folder_paths
65
import numpy as np
@@ -39,8 +38,8 @@ def load_image(self, image):
3938

4039
img = pillow(Image.open, image_path)
4140

42-
output_images: List[torch.Tensor] = []
43-
output_masks: List[torch.Tensor] = []
41+
output_images: list[torch.Tensor] = []
42+
output_masks: list[torch.Tensor] = []
4443
w, h = None, None
4544

4645
excluded_formats = ['MPO']

comfyui/video_generator/node_helpers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import hashlib
2-
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
2+
from collections.abc import Callable
3+
from typing import Any, TypeVar
34

45
import torch
56
from comfy.cli_args import args
@@ -8,9 +9,8 @@
89
T = TypeVar('T')
910

1011

11-
def conditioning_set_values(
12-
conditioning: List[Any],
13-
values: Optional[Dict[str, Any]] = None) -> List[Any]:
12+
def conditioning_set_values(conditioning: list[Any],
13+
values: dict[str, Any] | None = None) -> list[Any]:
1414
if values is None:
1515
values = {}
1616
c = []
@@ -48,7 +48,7 @@ def hasher() -> Callable[[], Any]:
4848
return hashfuncs[args.default_hashing_function]
4949

5050

51-
def string_to_torch_dtype(string: str) -> Optional[torch.dtype]:
51+
def string_to_torch_dtype(string: str) -> torch.dtype | None:
5252
if string == "fp32":
5353
return torch.float32
5454
if string == "fp16":
@@ -59,7 +59,7 @@ def string_to_torch_dtype(string: str) -> Optional[torch.dtype]:
5959

6060

6161
def image_alpha_fix(destination: torch.Tensor,
62-
source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
62+
source: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
6363
if destination.shape[-1] < source.shape[-1]:
6464
source = source[..., :destination.shape[-1]]
6565
elif destination.shape[-1] > source.shape[-1]:

docs/source/conf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import re
1919
import sys
2020
from pathlib import Path
21-
from typing import Optional
2221

2322
import requests
2423

@@ -168,8 +167,7 @@ def setup(app):
168167
_cached_branch: str = ""
169168

170169

171-
def get_repo_base_and_branch(
172-
pr_number: str) -> tuple[Optional[str], Optional[str]]:
170+
def get_repo_base_and_branch(pr_number: str) -> tuple[str | None, str | None]:
173171
global _cached_base, _cached_branch
174172
if _cached_base and _cached_branch:
175173
return _cached_base, _cached_branch

docs/source/generate_examples.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import re
66
from dataclasses import dataclass, field
77
from pathlib import Path
8-
from typing import Optional
98

109
ROOT_DIR = Path(__file__).parent.parent.parent.resolve()
1110
ROOT_DIR_RELATIVE = '../../../..'
@@ -89,7 +88,7 @@ class Example:
8988
generate() -> str: Generates the documentation content.
9089
""" # noqa: E501
9190
path: Path
92-
category: Optional[str] = None
91+
category: str | None = None
9392
main_file: Path = field(init=False)
9493
other_files: list[Path] = field(init=False)
9594
title: str = field(init=False)

fastvideo/v1/STA_configuration.py

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import os
44
from collections import defaultdict
5-
from typing import Any, Dict, List, Optional, Tuple
5+
from typing import Any
66

77
import numpy as np
88

@@ -13,7 +13,7 @@ def configure_sta(mode: str = 'STA_searching',
1313
layer_num: int = 40,
1414
time_step_num: int = 50,
1515
head_num: int = 40,
16-
**kwargs) -> List[List[List[Any]]]:
16+
**kwargs) -> list[list[list[Any]]]:
1717
"""
1818
Configure Sliding Tile Attention (STA) parameters based on the specified mode.
1919
@@ -53,22 +53,22 @@ def configure_sta(mode: str = 'STA_searching',
5353

5454
if mode == 'STA_searching':
5555
# Get parameters with defaults
56-
mask_candidates: Optional[List[str]] = kwargs.get('mask_candidates')
56+
mask_candidates: list[str] | None = kwargs.get('mask_candidates')
5757
if mask_candidates is None:
5858
raise ValueError(
5959
"mask_candidates is required for STA_searching mode")
60-
mask_selected: List[int] = kwargs.get('mask_selected',
60+
mask_selected: list[int] = kwargs.get('mask_selected',
6161
list(range(len(mask_candidates))))
6262

6363
# Parse selected masks
64-
selected_masks: List[List[int]] = []
64+
selected_masks: list[list[int]] = []
6565
for index in mask_selected:
6666
mask = mask_candidates[index]
6767
masks_list = [int(x) for x in mask.split(',')]
6868
selected_masks.append(masks_list)
6969

7070
# Create 3D mask structure with fixed dimensions (t=50, l=60)
71-
masks_3d: List[List[List[List[int]]]] = []
71+
masks_3d: list[list[list[list[int]]]] = []
7272
for i in range(time_step_num): # Fixed t dimension = 50
7373
row = []
7474
for j in range(layer_num): # Fixed l dimension = 60
@@ -79,25 +79,23 @@ def configure_sta(mode: str = 'STA_searching',
7979

8080
elif mode == 'STA_tuning':
8181
# Get required parameters
82-
mask_search_files_path: Optional[str] = kwargs.get(
82+
mask_search_files_path: str | None = kwargs.get(
8383
'mask_search_files_path')
8484
if not mask_search_files_path:
8585
raise ValueError(
8686
"mask_search_files_path is required for STA_tuning mode")
8787

8888
# Get optional parameters with defaults
89-
mask_candidates_tuning: Optional[List[str]] = kwargs.get(
90-
'mask_candidates')
89+
mask_candidates_tuning: list[str] | None = kwargs.get('mask_candidates')
9190
if mask_candidates_tuning is None:
9291
raise ValueError("mask_candidates is required for STA_tuning mode")
93-
mask_selected_tuning: List[int] = kwargs.get(
92+
mask_selected_tuning: list[int] = kwargs.get(
9493
'mask_selected', list(range(len(mask_candidates_tuning))))
95-
skip_time_steps_tuning: Optional[int] = kwargs.get('skip_time_steps')
96-
save_dir_tuning: Optional[str] = kwargs.get('save_dir',
97-
"mask_candidates")
94+
skip_time_steps_tuning: int | None = kwargs.get('skip_time_steps')
95+
save_dir_tuning: str | None = kwargs.get('save_dir', "mask_candidates")
9896

9997
# Parse selected masks
100-
selected_masks_tuning: List[List[int]] = []
98+
selected_masks_tuning: list[list[int]] = []
10199
for index in mask_selected_tuning:
102100
mask = mask_candidates_tuning[index]
103101
masks_list = [int(x) for x in mask.split(',')]
@@ -108,7 +106,7 @@ def configure_sta(mode: str = 'STA_searching',
108106
averaged_results = average_head_losses(results, selected_masks_tuning)
109107

110108
# Add full attention mask for specific cases
111-
full_attention_mask_tuning: Optional[List[int]] = kwargs.get(
109+
full_attention_mask_tuning: list[int] | None = kwargs.get(
112110
'full_attention_mask')
113111
if full_attention_mask_tuning is not None:
114112
selected_masks_tuning.append(full_attention_mask_tuning)
@@ -149,28 +147,28 @@ def configure_sta(mode: str = 'STA_searching',
149147
return mask_strategy_3d
150148
elif mode == 'STA_tuning_cfg':
151149
# Get required parameters for both positive and negative paths
152-
mask_search_files_path_pos: Optional[str] = kwargs.get(
150+
mask_search_files_path_pos: str | None = kwargs.get(
153151
'mask_search_files_path_pos')
154-
mask_search_files_path_neg: Optional[str] = kwargs.get(
152+
mask_search_files_path_neg: str | None = kwargs.get(
155153
'mask_search_files_path_neg')
156-
save_dir_cfg: Optional[str] = kwargs.get('save_dir')
154+
save_dir_cfg: str | None = kwargs.get('save_dir')
157155

158156
if not mask_search_files_path_pos or not mask_search_files_path_neg or not save_dir_cfg:
159157
raise ValueError(
160158
"mask_search_files_path_pos, mask_search_files_path_neg, and save_dir are required for STA_tuning_cfg mode"
161159
)
162160

163161
# Get optional parameters with defaults
164-
mask_candidates_cfg: Optional[List[str]] = kwargs.get('mask_candidates')
162+
mask_candidates_cfg: list[str] | None = kwargs.get('mask_candidates')
165163
if mask_candidates_cfg is None:
166164
raise ValueError(
167165
"mask_candidates is required for STA_tuning_cfg mode")
168-
mask_selected_cfg: List[int] = kwargs.get(
166+
mask_selected_cfg: list[int] = kwargs.get(
169167
'mask_selected', list(range(len(mask_candidates_cfg))))
170-
skip_time_steps_cfg: Optional[int] = kwargs.get('skip_time_steps')
168+
skip_time_steps_cfg: int | None = kwargs.get('skip_time_steps')
171169

172170
# Parse selected masks
173-
selected_masks_cfg: List[List[int]] = []
171+
selected_masks_cfg: list[list[int]] = []
174172
for index in mask_selected_cfg:
175173
mask = mask_candidates_cfg[index]
176174
masks_list = [int(x) for x in mask.split(',')]
@@ -187,7 +185,7 @@ def configure_sta(mode: str = 'STA_searching',
187185
selected_masks_cfg)
188186

189187
# Add full attention mask for specific cases
190-
full_attention_mask_cfg: Optional[List[int]] = kwargs.get(
188+
full_attention_mask_cfg: list[int] | None = kwargs.get(
191189
'full_attention_mask')
192190
if full_attention_mask_cfg is not None:
193191
selected_masks_cfg.append(full_attention_mask_cfg)
@@ -227,7 +225,7 @@ def configure_sta(mode: str = 'STA_searching',
227225

228226
else: # STA_inference
229227
# Get parameters with defaults
230-
load_path: Optional[str] = kwargs.get(
228+
load_path: str | None = kwargs.get(
231229
'load_path', "mask_candidates/mask_strategy.json")
232230
if load_path is None:
233231
raise ValueError("load_path is required for STA_inference mode")
@@ -248,9 +246,9 @@ def configure_sta(mode: str = 'STA_searching',
248246
# Helper functions
249247

250248

251-
def read_specific_json_files(folder_path: str) -> List[Dict[str, Any]]:
249+
def read_specific_json_files(folder_path: str) -> list[dict[str, Any]]:
252250
"""Read and parse JSON files containing mask search results."""
253-
json_contents: List[Dict[str, Any]] = []
251+
json_contents: list[dict[str, Any]] = []
254252

255253
# List files only in the current directory (no walk)
256254
files = os.listdir(folder_path)
@@ -268,11 +266,11 @@ def read_specific_json_files(folder_path: str) -> List[Dict[str, Any]]:
268266

269267

270268
def average_head_losses(
271-
results: List[Dict[str, Any]],
272-
selected_masks: List[List[int]]) -> Dict[str, Dict[str, np.ndarray]]:
269+
results: list[dict[str, Any]],
270+
selected_masks: list[list[int]]) -> dict[str, dict[str, np.ndarray]]:
273271
"""Average losses across all prompts for each mask strategy."""
274272
# Initialize a dictionary to store the averaged results
275-
averaged_losses: Dict[str, Dict[str, np.ndarray]] = {}
273+
averaged_losses: dict[str, dict[str, np.ndarray]] = {}
276274
loss_type = 'L2_loss'
277275
# Get all loss types (e.g., 'L2_loss')
278276
averaged_losses[loss_type] = {}
@@ -294,14 +292,14 @@ def average_head_losses(
294292

295293

296294
def select_best_mask_strategy(
297-
averaged_results: Dict[str, Dict[str, np.ndarray]],
298-
selected_masks: List[List[int]],
295+
averaged_results: dict[str, dict[str, np.ndarray]],
296+
selected_masks: list[list[int]],
299297
skip_time_steps: int = 12,
300298
timesteps: int = 50,
301299
head_num: int = 40
302-
) -> Tuple[Dict[str, List[int]], float, Dict[str, int]]:
300+
) -> tuple[dict[str, list[int]], float, dict[str, int]]:
303301
"""Select the best mask strategy for each head based on loss minimization."""
304-
best_mask_strategy: Dict[str, List[int]] = {}
302+
best_mask_strategy: dict[str, list[int]] = {}
305303
loss_type = 'L2_loss'
306304
# Get the shape of time steps and layers
307305
layers = len(averaged_results[loss_type][str(selected_masks[0])][0])
@@ -310,7 +308,7 @@ def select_best_mask_strategy(
310308
total_tokens = 0 # total number of masked tokens
311309
total_length = 0 # total sequence length
312310

313-
strategy_counts: Dict[str, int] = {
311+
strategy_counts: dict[str, int] = {
314312
str(strategy): 0
315313
for strategy in selected_masks
316314
}
@@ -352,22 +350,22 @@ def select_best_mask_strategy(
352350

353351

354352
def save_mask_search_results(
355-
mask_search_final_result: List[Dict[str, List[float]]],
353+
mask_search_final_result: list[dict[str, list[float]]],
356354
prompt: str,
357-
mask_strategies: List[str],
358-
output_dir: str = 'output/mask_search_result/') -> Optional[str]:
355+
mask_strategies: list[str],
356+
output_dir: str = 'output/mask_search_result/') -> str | None:
359357
if not mask_search_final_result:
360358
print("No mask search results to save")
361359
return None
362360

363361
# Create result dictionary with defaultdict for nested lists
364-
mask_search_dict: Dict[str, Dict[str, List[List[float]]]] = {
362+
mask_search_dict: dict[str, dict[str, list[list[float]]]] = {
365363
"L2_loss": defaultdict(list),
366364
"L1_loss": defaultdict(list)
367365
}
368366

369367
mask_selected = list(range(len(mask_strategies)))
370-
selected_masks: List[List[int]] = []
368+
selected_masks: list[list[int]] = []
371369
for index in mask_selected:
372370
mask = mask_strategies[index]
373371
masks_list = [int(x) for x in mask.split(',')]
@@ -377,7 +375,7 @@ def save_mask_search_results(
377375
for i, mask_strategy in enumerate(selected_masks):
378376
mask_strategy_str = str(mask_strategy)
379377
# Process L2 loss
380-
step_results: List[List[float]] = []
378+
step_results: list[list[float]] = []
381379
for step_data in mask_search_final_result:
382380
if isinstance(step_data, dict) and "L2_loss" in step_data:
383381
layer_losses = [float(loss) for loss in step_data["L2_loss"]]

fastvideo/v1/attention/backends/abstract.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33

44
from abc import ABC, abstractmethod
55
from dataclasses import dataclass, fields
6-
from typing import (TYPE_CHECKING, Any, Dict, Generic, Optional, Protocol, Set,
7-
Type, TypeVar)
6+
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
87

98
if TYPE_CHECKING:
109
from fastvideo.v1.fastvideo_args import FastVideoArgs
@@ -27,12 +26,12 @@ def get_name() -> str:
2726

2827
@staticmethod
2928
@abstractmethod
30-
def get_impl_cls() -> Type["AttentionImpl"]:
29+
def get_impl_cls() -> type["AttentionImpl"]:
3130
raise NotImplementedError
3231

3332
@staticmethod
3433
@abstractmethod
35-
def get_metadata_cls() -> Type["AttentionMetadata"]:
34+
def get_metadata_cls() -> type["AttentionMetadata"]:
3635
raise NotImplementedError
3736

3837
# @staticmethod
@@ -46,7 +45,7 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
4645

4746
@staticmethod
4847
@abstractmethod
49-
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
48+
def get_builder_cls() -> type["AttentionMetadataBuilder"]:
5049
raise NotImplementedError
5150

5251

@@ -57,8 +56,7 @@ class AttentionMetadata:
5756
current_timestep: int
5857

5958
def asdict_zerocopy(self,
60-
skip_fields: Optional[Set[str]] = None
61-
) -> Dict[str, Any]:
59+
skip_fields: set[str] | None = None) -> dict[str, Any]:
6260
"""Similar to dataclasses.asdict, but avoids deepcopying."""
6361
if skip_fields is None:
6462
skip_fields = set()
@@ -124,7 +122,7 @@ def __init__(
124122
head_size: int,
125123
softmax_scale: float,
126124
causal: bool = False,
127-
num_kv_heads: Optional[int] = None,
125+
num_kv_heads: int | None = None,
128126
prefix: str = "",
129127
**extra_impl_args,
130128
) -> None:

0 commit comments

Comments
 (0)