Skip to content

Commit 577468c

Browse files
authored
fix: enable invalid assignment ty checks (#553)
* fix: enable invalid assignment ty checks
* fix: ruff errors
* chore: undo type checking for notebooks
* fix: type check fixes for ty v0.0.20
* fix: ruff & numpy docstring check incompatibility
1 parent 2b63353 commit 577468c

35 files changed

+83
-73
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ exclude = ["test_*.py"]
33
line-length = 121
44

55
[tool.ruff.lint]
6-
ignore = ["ANN002", "ANN003", "ANN401", "C901", "D100", "D104", "D401", "D406", "C408", "PTH123", "N813"]
6+
ignore = ["ANN002", "ANN003", "ANN401", "C901", "D100", "D104", "D401", "D406", "D420", "C408", "PTH123", "N813"]
77
select = ["A", "C", "CPY", "D", "E", "ERA", "F", "FIX", "I", "N", "SIM", "T20", "W", "FA", "PTH"]
88
preview = true # enable preview features for copyright checking
99

@@ -22,7 +22,6 @@ author = "- Pruna AI GmbH"
2222
[tool.ty.rules]
2323
# Ignore rules that are stricter than mypy for transition period
2424
unresolved-import = "ignore"
25-
invalid-assignment = "ignore" # mypy is more permissive with Any assignments
2625
call-non-callable = "ignore" # mypy allows more dynamic method calls
2726
index-out-of-bounds = "ignore" # mypy is more permissive with tuple indexing
2827
unresolved-attribute = "ignore" # mypy is more permissive with module attributes
@@ -38,6 +37,7 @@ possibly-missing-attribute = "ignore"
3837
missing-argument = "ignore"
3938
unused-type-ignore-comment = "ignore"
4039

40+
4141
[tool.coverage.run]
4242
source = ["src/pruna"]
4343

src/pruna/algorithms/base/registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import inspect
1616
import logging
1717
import pkgutil
18-
from typing import Any, Callable, Dict
18+
from typing import Any, Dict
1919

2020
from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
2121
from pruna.algorithms.base.tags import AlgorithmTag
@@ -28,7 +28,7 @@ class AlgorithmRegistry:
2828
The registry is a dictionary that maps algorithm names to algorithm instances.
2929
"""
3030

31-
_registry: Dict[str, Callable[..., Any]] = {}
31+
_registry: Dict[str, PrunaAlgorithmBase] = {}
3232

3333
@classmethod
3434
def discover_algorithms(cls, algorithms_pkg: Any) -> None:

src/pruna/algorithms/c_translate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ConfigSpace import OrdinalHyperparameter
2727
from transformers import (
2828
AutomaticSpeechRecognitionPipeline,
29+
PretrainedConfig,
2930
WhisperConfig,
3031
)
3132
from transformers.modeling_utils import PreTrainedModel
@@ -63,7 +64,7 @@ class CTranslate(PrunaAlgorithmBase):
6364
"""
6465

6566
algorithm_name: str = "c_translate"
66-
group_tags: list[str] = [tags.COMPILER]
67+
group_tags: list[tags] = [tags.COMPILER]
6768
save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.save_before_apply
6869
references = {"GitHub": "https://github.com/OpenNMT/CTranslate2"}
6970
tokenizer_required: bool = True
@@ -345,6 +346,7 @@ def __init__(self, generator: PreTrainedModel, output_dir: str, tokenizer: PreTr
345346
self.output_dir = output_dir
346347
self.task = "generation"
347348
self.tokenizer = tokenizer
349+
self.config: PretrainedConfig | None = None
348350

349351
def __getattr__(self, name: str) -> Any:
350352
"""
@@ -416,6 +418,7 @@ def __init__(self, translator: PreTrainedModel, output_dir: str, tokenizer: PreT
416418
self.output_dir = output_dir
417419
self.task = "translation"
418420
self.tokenizer = tokenizer
421+
self.config: PretrainedConfig | None = None
419422

420423
def __getattr__(self, name: str) -> Any:
421424
"""
@@ -499,6 +502,7 @@ def __init__(self, whisper: Whisper, output_dir: str, processor: ProcessorMixin)
499502
self.processor = processor
500503
self.language = None
501504
self.prompt = None
505+
self.config: PretrainedConfig | None = None
502506

503507
def __getattr__(self, name: str) -> Any:
504508
"""

src/pruna/algorithms/deepcache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class DeepCache(PrunaAlgorithmBase):
3333
"""
3434

3535
algorithm_name: str = "deepcache"
36-
group_tags: list[str] = [tags.CACHER]
36+
group_tags: list[tags] = [tags.CACHER]
3737
save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.reapply
3838
references: dict[str, str] = {
3939
"GitHub": "https://github.com/horseee/DeepCache",

src/pruna/algorithms/fastercache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class FasterCache(PrunaAlgorithmBase):
4747
"""
4848

4949
algorithm_name: str = "fastercache"
50-
group_tags: list[str] = [tags.CACHER]
50+
group_tags: list[tags] = [tags.CACHER]
5151
save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.reapply
5252
references: dict[str, str] = {
5353
"GitHub": "https://github.com/Vchitect/FasterCache",

src/pruna/algorithms/flash_attn3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class FlashAttn3(PrunaAlgorithmBase):
4141
"""
4242

4343
algorithm_name: str = "flash_attn3"
44-
group_tags: list[str] = [tags.KERNEL]
44+
group_tags: list[tags] = [tags.KERNEL]
4545
save_fn = SAVE_FUNCTIONS.reapply
4646
references: dict[str, str] = {
4747
"GitHub": "https://github.com/Dao-AILab/flash-attention",

src/pruna/algorithms/fora.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class FORA(PrunaAlgorithmBase):
3737
"""
3838

3939
algorithm_name: str = "fora"
40-
group_tags: list[str] = [tags.CACHER]
40+
group_tags: list[tags] = [tags.CACHER]
4141
save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.reapply
4242
references: dict[str, str] = {"Paper": "https://arxiv.org/abs/2407.01425"}
4343
tokenizer_required: bool = False
@@ -163,8 +163,8 @@ def __init__(self, pipe: Any, interval: int, start_step: int, backbone_calls_per
163163
self.single_stream_blocks_forward: Dict[int, Callable] = {}
164164

165165
# Use separate caches for the two different transformer block types
166-
self.double_stream_blocks_cache: Dict[int, Tuple[Any, Any]] = {}
167-
self.single_stream_blocks_cache: Dict[int, Any] = {}
166+
self.double_stream_blocks_cache: Dict[Tuple[int, int], Tuple[Any, Any]] = {}
167+
self.single_stream_blocks_cache: Dict[Tuple[int, int], Any] = {}
168168

169169
def get_cache_schedule(self, num_steps: int) -> list[int]:
170170
"""

src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_distiller.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,24 @@
1414

1515
from __future__ import annotations
1616

17+
import contextlib
1718
import functools
19+
import pathlib
1820
import random
1921
from typing import Any, List, Literal
2022

2123
import pytorch_lightning as pl
2224
import torch
23-
from diffusers.optimization import get_scheduler
24-
from diffusers.utils import BaseOutput
25-
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
26-
from pytorch_lightning.utilities.seed import isolate_rng
27-
28-
try:
29-
from bitsandbytes.optim import AdamW8bit # type: ignore[import-untyped]
30-
except ImportError:
31-
AdamW8bit = None
32-
33-
import pathlib
34-
3525
from ConfigSpace import (
3626
CategoricalHyperparameter,
3727
Constant,
3828
UniformFloatHyperparameter,
3929
UniformIntegerHyperparameter,
4030
)
31+
from diffusers.optimization import get_scheduler
32+
from diffusers.utils import BaseOutput
33+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
34+
from pytorch_lightning.utilities.seed import isolate_rng
4135

4236
from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner
4337
from pruna.algorithms.global_utils.recovery.finetuners.diffusers import utils
@@ -56,6 +50,10 @@
5650
from pruna.engine.utils import get_device, get_device_type
5751
from pruna.logging.logger import pruna_logger
5852

53+
AdamW8bit: type[Any] | None = None
54+
with contextlib.suppress(ImportError):
55+
from bitsandbytes.optim import AdamW8bit
56+
5957

6058
class TextToImageDistiller(PrunaFinetuner):
6159
"""Distiller for text-to-image models."""

src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_finetuner.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,20 @@
1414

1515
from __future__ import annotations
1616

17+
import contextlib
1718
from pathlib import Path
1819
from typing import Any, List, Literal, Tuple
1920

2021
import pytorch_lightning as pl
2122
import torch
22-
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
23-
from pytorch_lightning.utilities.seed import isolate_rng
24-
25-
try:
26-
from bitsandbytes.optim import AdamW8bit # type: ignore[import-untyped]
27-
except ImportError:
28-
AdamW8bit = None
29-
3023
from ConfigSpace import (
3124
CategoricalHyperparameter,
3225
Constant,
3326
UniformFloatHyperparameter,
3427
UniformIntegerHyperparameter,
3528
)
29+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
30+
from pytorch_lightning.utilities.seed import isolate_rng
3631

3732
from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner
3833
from pruna.algorithms.global_utils.recovery.finetuners.diffusers import (
@@ -49,6 +44,10 @@
4944
from pruna.config.smash_config import SmashConfigPrefixWrapper
5045
from pruna.logging.logger import pruna_logger
5146

47+
AdamW8bit: type[Any] | None = None
48+
with contextlib.suppress(ImportError):
49+
from bitsandbytes.optim import AdamW8bit
50+
5251

5352
class TextToImageFinetuner(PrunaFinetuner):
5453
"""Finetuner for text-to-image models."""

src/pruna/algorithms/gptq_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class GPTQ(PrunaAlgorithmBase):
3838
"""
3939

4040
algorithm_name: str = "gptq"
41-
group_tags: list[str] = [tags.QUANTIZER]
41+
group_tags: list[tags] = [tags.QUANTIZER]
4242
references: dict[str, str] = {"GitHub": "https://github.com/ModelCloud/GPTQModel"}
4343
save_fn: None = None
4444
tokenizer_required: bool = True

0 commit comments

Comments (0)