Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion .github/workflows/pr_dependency_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.10"
- name: Install dependencies
run: |
pip install -e .
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/pr_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.10"
- name: Install dependencies
run: |
pip install --upgrade pip
Expand All @@ -55,7 +55,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.10"
- name: Install dependencies
run: |
pip install --upgrade pip
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/pr_tests_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.10"
- name: Install dependencies
run: |
pip install --upgrade pip
Expand All @@ -56,7 +56,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.10"
- name: Install dependencies
run: |
pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr_torch_dependency_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.10"
- name: Install dependencies
run: |
pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pypi_publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.10"

- name: Install dependencies
run: |
Expand Down
146 changes: 146 additions & 0 deletions scripts/remove_typing_builtin_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
"""
Remove lower-case built-in generics imported from `typing`.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Iterable, Iterator, Sequence


try:
import libcst as cst
except ImportError as exc: # pragma: no cover - dependency guard
raise SystemExit("This script requires `libcst`. Install it via `pip install libcst` and retry.") from exc


BUILTIN_TYPING_NAMES = frozenset({"callable", "dict", "frozenset", "list", "set", "tuple", "type"})


class TypingBuiltinImportRemover(cst.CSTTransformer):
def __init__(self) -> None:
self.changed = False
self.removed: list[str] = []
self.warnings: list[str] = []

def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.BaseStatement:
module_name = self._module_name(updated_node.module)
if module_name != "typing":
return updated_node

names = updated_node.names
if isinstance(names, cst.ImportStar):
self.warnings.append("encountered `from typing import *` (skipped)")
return updated_node

new_aliases = []
removed_here: list[str] = []
for alias in names:
if isinstance(alias, cst.ImportStar):
self.warnings.append("encountered `from typing import *` (skipped)")
return updated_node
if not isinstance(alias.name, cst.Name):
new_aliases.append(alias)
continue
imported_name = alias.name.value
if imported_name in BUILTIN_TYPING_NAMES:
removed_here.append(imported_name)
continue
new_aliases.append(alias)

if not removed_here:
return updated_node

self.changed = True
self.removed.extend(removed_here)

if not new_aliases:
return cst.RemoveFromParent()
# Ensure trailing commas are removed.
formatted_aliases = []
for alias in new_aliases:
if alias.comma is not None and alias is new_aliases[-1]:
formatted_aliases.append(alias.with_changes(comma=None))
else:
formatted_aliases.append(alias)

return updated_node.with_changes(names=tuple(formatted_aliases))

def _module_name(self, node: cst.BaseExpression | None) -> str | None:
if node is None:
return None
if isinstance(node, cst.Name):
return node.value
if isinstance(node, cst.Attribute):
prefix = self._module_name(node.value)
if prefix is None:
return node.attr.value
return f"{prefix}.{node.attr.value}"
return None


def iter_python_files(paths: Iterable[Path]) -> Iterator[Path]:
for path in paths:
if path.is_dir():
yield from (p for p in path.rglob("*.py") if not p.name.startswith("."))
yield from (p for p in path.rglob("*.pyi") if not p.name.startswith("."))
elif path.suffix in {".py", ".pyi"}:
yield path


def process_file(path: Path, dry_run: bool) -> tuple[bool, TypingBuiltinImportRemover]:
source = path.read_text(encoding="utf-8")
module = cst.parse_module(source)
transformer = TypingBuiltinImportRemover()
updated = module.visit(transformer)

if not transformer.changed or source == updated.code:
return False, transformer

if not dry_run:
path.write_text(updated.code, encoding="utf-8")
return True, transformer


def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(description="Remove lower-case built-in generics imported from typing.")
parser.add_argument(
"paths",
nargs="*",
type=Path,
default=[Path("src")],
help="Files or directories to rewrite (default: src).",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Only report files that would change without writing them.",
)
args = parser.parse_args(argv)

files = sorted(set(iter_python_files(args.paths)))
if not files:
print("No Python files matched the provided paths.", file=sys.stderr)
return 1

changed_any = False
for path in files:
changed, transformer = process_file(path, dry_run=args.dry_run)
if changed:
changed_any = True
action = "Would update" if args.dry_run else "Updated"
removed = ", ".join(sorted(set(transformer.removed)))
print(f"{action}: {path} (removed typing imports: {removed})")
for warning in transformer.warnings:
print(f"Warning: {path}: {warning}", file=sys.stderr)

if not changed_any:
print("No changes needed.")
return 0


if __name__ == "__main__":
raise SystemExit(main())
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
"pytest",
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
"python>=3.9.0",
"ruff==0.9.10",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
Expand Down Expand Up @@ -287,7 +287,7 @@ def run(self):
packages=find_packages("src"),
package_data={"diffusers": ["py.typed"]},
include_package_data=True,
python_requires=">=3.8.0",
python_requires=">=3.9.0",
Copy link
Contributor

@tolgacangoz tolgacangoz Oct 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

install_requires=list(install_requires),
extras_require=extras,
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
Expand Down
24 changes: 12 additions & 12 deletions src/diffusers/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME
Expand Down Expand Up @@ -33,13 +33,13 @@ def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

@property
def tensor_inputs(self) -> List[str]:
def tensor_inputs(self) -> list[str]:
raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")

def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> dict[str, Any]:
raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")

def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)


Expand All @@ -49,14 +49,14 @@ class MultiPipelineCallbacks:
provides a unified interface for calling all of them.
"""

def __init__(self, callbacks: List[PipelineCallback]):
def __init__(self, callbacks: list[PipelineCallback]):
self.callbacks = callbacks

@property
def tensor_inputs(self) -> List[str]:
def tensor_inputs(self) -> list[str]:
return [input for callback in self.callbacks for input in callback.tensor_inputs]

def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
"""
Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
"""
Expand All @@ -76,7 +76,7 @@ class SDCFGCutoffCallback(PipelineCallback):

tensor_inputs = ["prompt_embeds"]

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

Expand Down Expand Up @@ -109,7 +109,7 @@ class SDXLCFGCutoffCallback(PipelineCallback):
"add_time_ids",
]

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

Expand Down Expand Up @@ -152,7 +152,7 @@ class SDXLControlnetCFGCutoffCallback(PipelineCallback):
"image",
]

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

Expand Down Expand Up @@ -195,7 +195,7 @@ class IPAdapterScaleCutoffCallback(PipelineCallback):

tensor_inputs = []

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

Expand All @@ -219,7 +219,7 @@ class SD3CFGCutoffCallback(PipelineCallback):

tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

Expand Down
Loading
Loading