
Commit ab078c5

fix: ty invalid-argument-type (#554)
fix: ty invalid-argument-type

Co-authored-by: Gaspar Rochette <gaspar.rochette@pruna.ai>

1 parent 577468c · commit ab078c5


57 files changed (+309, -230 lines)

pyproject.toml

Lines changed: 1 addition & 2 deletions
@@ -27,7 +27,6 @@ index-out-of-bounds = "ignore" # mypy is more permissive with tuple indexing
 unresolved-attribute = "ignore" # mypy is more permissive with module attributes
 redundant-cast = "ignore" # mypy doesn't warn about redundant casts
 unsupported-operator = "ignore" # mypy supports | syntax with from __future__ import annotations
-invalid-argument-type = "ignore" # mypy is more permissive with argument types
 invalid-return-type = "ignore" # mypy is more permissive with return types
 invalid-parameter-default = "ignore" # mypy is more permissive with parameter defaults
 no-matching-overload = "ignore" # mypy is more permissive with overloads
@@ -197,7 +196,7 @@ dev = [
     "pre-commit",
     "twine",
     "pyc-wheel",
-    "ruff",
+    "ruff>=0.15.3", # introduction of D420 rule
     "numpydoc>=1.9.0",
     "numpydoc-validation",
     "pytest",

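Context for the change above: ty reports invalid-argument-type when a call site passes a value whose inferred type is not assignable to the parameter's declared type. With the blanket ignore removed from pyproject.toml, every such diagnostic must now be fixed or suppressed inline, which is what the rest of this commit does. A minimal sketch of the diagnostic and the targeted suppression, using a hypothetical function that is not from this repo:

def scale(values: list[float], factor: float) -> list[float]:
    # Multiply every element by `factor`.
    return [v * factor for v in values]

data: list[int] | list[float] = [1, 2, 3]

# ty would report invalid-argument-type here: `list[int] | list[float]`
# is not assignable to `list[float]`. The bracketed ignore silences only
# this rule on only this line, mirroring the suppressions in this commit.
result = scale(data, 2.0)  # type: ignore[invalid-argument-type]
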
src/pruna/algorithms/c_translate.py

Lines changed: 3 additions & 3 deletions
@@ -98,7 +98,7 @@ def get_hyperparameters(self) -> list:
                 "weight_bits",
                 sequence=[8, 16],
                 default_value=16,
-                meta=dict(desc="Sets the number of bits to use for weight quantization."),
+                meta={"desc": "Sets the number of bits to use for weight quantization."},
             ),
         ]
 
@@ -392,7 +392,7 @@ def __call__(
             The generated sequence.
         """
         if type(x) is dict or isinstance(x, transformers.tokenization_utils_base.BatchEncoding):
-            x_tensor = x["input_ids"]
+            x_tensor = x["input_ids"]  # type: ignore[invalid-argument-type]
         else:
             x_tensor = x
         token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))]  # type: ignore[not-subscriptable]
@@ -468,7 +468,7 @@ def __call__(
         if "max_length" in kwargs:
             max_decoding_length = kwargs["max_length"]
         if type(x) is dict or isinstance(x, transformers.tokenization_utils_base.BatchEncoding):
-            x_tensor = x["input_ids"]
+            x_tensor = x["input_ids"]  # type: ignore[invalid-argument-type]
         else:
             x_tensor = x
         token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))]  # type: ignore[not-subscriptable]

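The two suppressions above stem from a union-typed input: after the `or` of two different runtime checks, the static type of `x` is generally still the full tensor/BatchEncoding union on the true branch, so the string subscript can be reported against the non-mapping members. A minimal sketch of the same shape, with hypothetical stand-in types and assuming this narrowing behavior:

from typing import Union

def first_token(x: Union[dict[str, list[int]], list[int]]) -> int:
    # An `or` of two different runtime checks does not narrow `x` to the
    # mapping branch, so a checker can flag the string key as an invalid
    # argument to `list.__getitem__` for the `list[int]` member.
    if type(x) is dict or hasattr(x, "keys"):
        ids = x["input_ids"]  # type: ignore[invalid-argument-type]
    else:
        ids = x
    # `ids` is still a union here, hence the second, differently coded
    # suppression on the later subscript, as in the file above.
    return ids[0]  # type: ignore[not-subscriptable]
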
src/pruna/algorithms/deepcache.py

Lines changed: 3 additions & 3 deletions
@@ -80,9 +80,9 @@ def get_hyperparameters(self) -> list:
                 "interval",
                 sequence=[1, 2, 3, 4, 5],
                 default_value=2,
-                meta=dict(
-                    desc="Interval at which to cache - 1 disables caching. Higher is faster but might affect quality."
-                ),
+                meta={
+                    "desc": "Interval at which to cache - 1 disables caching. Higher is faster but might affect quality."
+                },
             ),
         ]

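Most of this commit is the mechanical `meta=dict(...)` to `meta={...}` rewrite seen above. A dict literal gives the checker the key and value types directly, while the keyword-argument `dict(...)` constructor spelling goes through overload resolution that ty may type less precisely; the literal is also the idiomatic form (ruff's C408 flags unnecessary `dict()` calls). A small before/after sketch with a hypothetical stand-in for the hyperparameter constructors:

from typing import Any, Mapping, Optional

def make_param(name: str, meta: Optional[Mapping[str, Any]] = None) -> tuple:
    # Toy stand-in for a ConfigSpace-style hyperparameter constructor.
    return (name, dict(meta or {}))

# Constructor-call spelling: the argument type flows through dict()'s
# keyword-argument overload.
p1 = make_param("interval", meta=dict(desc="Interval at which to cache."))

# Literal spelling: keys and values are typed straight from the literal.
p2 = make_param("interval", meta={"desc": "Interval at which to cache."})
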
src/pruna/algorithms/denoise.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def get_hyperparameters(self) -> list:
                 upper=1.0,
                 default_value=0.02,
                 log=False,
-                meta=dict(desc="Strength of the denoising/refinement. Lower values mean less change/more refinement."),
+                meta={"desc": "Strength of the denoising/refinement. Lower values mean less change/more refinement."},
             ),
         ]

src/pruna/algorithms/fastercache.py

Lines changed: 3 additions & 3 deletions
@@ -74,10 +74,10 @@ def get_hyperparameters(self) -> list:
                 "interval",
                 sequence=[1, 2, 3, 4, 5],
                 default_value=2,
-                meta=dict(
-                    desc="Interval at which to cache spatial attention blocks - 1 disables caching."
+                meta={
+                    "desc": "Interval at which to cache spatial attention blocks - 1 disables caching."
                     "Higher is faster but might degrade quality."
-                ),
+                },
             ),
         ]

src/pruna/algorithms/flash_attn3.py

Lines changed: 6 additions & 2 deletions
@@ -222,7 +222,11 @@ def _flash_attention_3(
             )
         else:
             out, _, *_ = torch.ops.flash_attn_pruna._flash_attn_forward(
-                q=query, k=key, v=value, softmax_scale=scale, causal=is_causal
+                q=query,  # type: ignore
+                k=key,  # type: ignore
+                v=value,  # type: ignore
+                softmax_scale=scale,  # type: ignore
+                causal=is_causal,  # type: ignore
             )
         return out
 
@@ -286,7 +290,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None):  # noqa: D105
         def _flash_attention3(query, key, value, *, is_causal=False, softmax_scale=None, kernel=None):
             # convert (B, H, S, D) → (B, S, H, D)
             q, k, v = [x.transpose(1, 2).contiguous() for x in (query, key, value)]
-            out, _ = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale)
+            out, _ = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale)  # type: ignore
             # back to (B, H, S, D) for the rest of the pipeline
             return out.transpose(1, 2)

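The blanket ignores above exist because `torch.ops.<namespace>.<op>` attributes are resolved dynamically at runtime: the op's schema is registered with the dispatcher rather than expressed as a Python signature, so a static checker generally cannot validate the arguments. A minimal sketch of the same situation with a hypothetical custom op (`my_ns::double` is illustrative, not part of this repo):

import torch

# The schema below exists at runtime, but the attribute access
# `torch.ops.my_ns.double` is dynamic, so checkers see no signature.
torch.library.define("my_ns::double", "(Tensor x) -> Tensor")

@torch.library.impl("my_ns::double", "cpu")
def _double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

y = torch.ops.my_ns.double(torch.ones(3))  # type: ignore
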
src/pruna/algorithms/fora.py

Lines changed: 4 additions & 3 deletions
@@ -70,19 +70,20 @@ def get_hyperparameters(self) -> list:
                 "interval",
                 sequence=range(1, 6),
                 default_value=2,
-                meta=dict(desc="Interval at which the outputs are computed. Higher is faster, but reduces quality."),
+                meta={"desc": "Interval at which the outputs are computed. Higher is faster, but reduces quality."},
             ),
             OrdinalHyperparameter(
                 "start_step",
                 sequence=range(11),
                 default_value=2,
-                meta=dict(desc="How many steps to wait before starting to cache."),
+                meta={"desc": "How many steps to wait before starting to cache."},
             ),
             OrdinalHyperparameter(
                 "backbone_calls_per_step",
                 sequence=range(1, 4),
                 default_value=1,
-                meta=dict(desc="Number of backbone forward passes per diffusion step (e.g., 2 for CFG)."),
+                meta={"desc": "Number of backbone forward passes per diffusion step (e.g., 2 for CFG)."}
+
             ),
         ]

src/pruna/algorithms/global_utils/recovery/adapters/lora.py

Lines changed: 8 additions & 8 deletions
@@ -73,30 +73,30 @@ def get_hyperparameters(cls, task_name: str, **override_defaults: Any) -> list:
                 "r",
                 sequence=[4, 8, 16, 32, 64, 128],
                 default_value=default_hyperparameters["r"],
-                meta=dict(desc="Rank of the LoRA layers."),
+                meta={"desc": "Rank of the LoRA layers."},
             ),
             OrdinalHyperparameter(
                 "alpha_r_ratio",
                 sequence=[0.5, 1.0, 2.0],
                 default_value=default_hyperparameters["alpha_r_ratio"],
-                meta=dict(desc="Alpha/Rank ratio of the LoRA layers."),
+                meta={"desc": "Alpha/Rank ratio of the LoRA layers."},
             ),
             CategoricalHyperparameter(
                 "target_modules",
                 choices=[None, "all-linear"],
                 default_value=default_hyperparameters["target_modules"],
-                meta=dict(desc="Target modules for the LoRA layers."),
+                meta={"desc": "Target modules for the LoRA layers."},
             ),
             Constant(
                 "dropout",
                 default_hyperparameters["dropout"],
-                meta=dict(desc="Dropout rate of the LoRA layers during training."),
+                meta={"desc": "Dropout rate of the LoRA layers during training."},
             ),
             CategoricalHyperparameter(
                 "variant",
                 choices=["lora", "pissa"],
                 default_value=default_hyperparameters["variant"],
-                meta=dict(desc="Variant of the LoRA adapter."),
+                meta={"desc": "Variant of the LoRA adapter."},
             ),
         ]
 
@@ -116,13 +116,13 @@ def get_hyperparameters(cls, task_name: str, **override_defaults: Any) -> list:
                 "r",
                 sequence=[4, 8, 16, 32, 64, 128],
                 default_value=default_hyperparameters["r"],
-                meta=dict(desc="Rank of the LoRA layers."),
+                meta={"desc": "Rank of the LoRA layers."},
             ),
             OrdinalHyperparameter(
                 "alpha_r_ratio",
                 sequence=[0.5, 1.0, 2.0],
                 default_value=default_hyperparameters["alpha_r_ratio"],
-                meta=dict(desc="Alpha/Rank ratio of the LoRA layers."),
+                meta={"desc": "Alpha/Rank ratio of the LoRA layers."},
             ),
             Constant(
                 "target_modules", default_hyperparameters["target_modules"]
@@ -132,7 +132,7 @@ def get_hyperparameters(cls, task_name: str, **override_defaults: Any) -> list:
                 "variant",
                 choices=["lora", "pissa"],
                 default_value=default_hyperparameters["variant"],
-                meta=dict(desc="Variant of the LoRA adapter."),
+                meta={"desc": "Variant of the LoRA adapter."},
             ),
         ]
     else:

src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_distiller.py

Lines changed: 12 additions & 12 deletions
@@ -95,67 +95,67 @@ def get_hyperparameters(cls, **override_defaults) -> List:
                 lower=0,
                 upper=4096,
                 default_value=numeric_defaults["training_batch_size"],
-                meta=dict(desc="Number of steps from each diffusion process to use for distillation."),
+                meta={"desc": "Number of steps from each diffusion process to use for distillation."},
             ),
             UniformIntegerHyperparameter(
                 "gradient_accumulation_steps",
                 lower=1,
                 upper=1024,
                 default_value=numeric_defaults["gradient_accumulation_steps"],
-                meta=dict(desc="Number of captions processed to estimate each gradient step."),
+                meta={"desc": "Number of captions processed to estimate each gradient step."},
             ),
             UniformIntegerHyperparameter(
                 "num_epochs",
                 lower=0,
                 upper=4096,
                 default_value=numeric_defaults["num_epochs"],
-                meta=dict(desc="Number of epochs for distillation."),
+                meta={"desc": "Number of epochs for distillation."},
             ),
             UniformFloatHyperparameter(
                 "validate_every_n_epoch",
                 lower=0.0,
                 upper=4096.0,
                 default_value=numeric_defaults["validate_every_n_epoch"],
-                meta=dict(
-                    desc="Number of epochs between each round of validation and model checkpointing. "
+                meta={
+                    "desc": "Number of epochs between each round of validation and model checkpointing. "
                     "If the value is between 0 and 1, validation will be performed multiple times per epoch, "
                     "e.g. 1/8 will result in 8 validations per epoch."
-                ),
+                },
             ),
             UniformFloatHyperparameter(
                 "learning_rate",
                 lower=0.0,
                 upper=1.0,
                 default_value=numeric_defaults["learning_rate"],
-                meta=dict(desc="Learning rate for distillation."),
+                meta={"desc": "Learning rate for distillation."},
             ),
             Constant("weight_decay", numeric_defaults["weight_decay"]),
             # report_to: for consistency with text-to-text-lora but wandb and tensorboard are not supported yet
             Constant("report_to", string_defaults["report_to"]),
             Boolean(
                 "use_cpu_offloading",
                 default=False,
-                meta=dict(desc="Whether to use CPU offloading for distillation."),
+                meta={"desc": "Whether to use CPU offloading for distillation."},
             ),
             CategoricalHyperparameter(
                 "optimizer",
                 choices=["AdamW8bit", "AdamW", "Adam"],
                 default_value=string_defaults["optimizer"],
-                meta=dict(desc="Which optimizer to use for distillation."),
+                meta={"desc": "Which optimizer to use for distillation."},
             ),
             UniformFloatHyperparameter(
                 "lr_decay",
                 lower=0.0,
                 upper=1.0,
                 default_value=numeric_defaults["lr_decay"],
-                meta=dict(desc="Learning rate decay, applied at each epoch."),
+                meta={"desc": "Learning rate decay, applied at each epoch."},
             ),
             UniformIntegerHyperparameter(
                 "warmup_steps",
                 lower=0,
                 upper=2**14,
                 default_value=numeric_defaults["warmup_steps"],
-                meta=dict(desc="Number of warmup steps for the learning rate scheduler."),
+                meta={"desc": "Number of warmup steps for the learning rate scheduler."},
             ),
         ]
 
@@ -405,7 +405,7 @@ def distillation_forward(*args, **kwargs):
                 output["sample"] if ("return_dict" in kwargs and kwargs["return_dict"]) else output[0]
             )
             loss = self.loss(latent_output, latent_targets[self.num_previous_steps])
-            if is_training:
+            if is_training and active_steps is not None:
                 accumulation_normalized_loss = loss / (len(active_steps) * self.gradient_accumulation_steps)
                 self.manual_backward(accumulation_normalized_loss)
             diffusion_step_losses.append(loss)

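The `and active_steps is not None` guard above is the standard way to satisfy the checker before taking `len()` of an optional value: the comparison narrows the type inside the branch, so no suppression is needed. A standalone sketch of the same pattern:

from typing import Optional

def normalized_loss(loss: float, active_steps: Optional[list[int]], accum_steps: int) -> float:
    # Without the guard, `len(active_steps)` is flagged because
    # `list[int] | None` is not a valid argument to `len()`.
    if active_steps is not None:
        return loss / (len(active_steps) * accum_steps)
    return loss
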
src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_finetuner.py

Lines changed: 13 additions & 11 deletions
@@ -86,53 +86,53 @@ def get_hyperparameters(cls, **override_defaults) -> List:
                 lower=0,
                 upper=4096,
                 default_value=numeric_defaults["training_batch_size"],
-                meta=dict(desc="Batch size for finetuning."),
+                meta={"desc": "Batch size for finetuning."},
             ),
             UniformIntegerHyperparameter(
                 "gradient_accumulation_steps",
                 lower=1,
                 upper=1024,
                 default_value=numeric_defaults["gradient_accumulation_steps"],
-                meta=dict(desc="Number of gradient accumulation steps for finetuning."),
+                meta={"desc": "Number of gradient accumulation steps for finetuning."},
             ),
             UniformIntegerHyperparameter(
                 "num_epochs",
                 lower=0,
                 upper=4096,
                 default_value=numeric_defaults["num_epochs"],
-                meta=dict(desc="Number of epochs for finetuning."),
+                meta={"desc": "Number of epochs for finetuning."},
             ),
             UniformFloatHyperparameter(
                 "validate_every_n_epoch",
                 lower=0.0,
                 upper=4096.0,
                 default_value=numeric_defaults["validate_every_n_epoch"],
-                meta=dict(
-                    desc="Number of epochs between each round of validation and model checkpointing. "
+                meta={
+                    "desc": "Number of epochs between each round of validation and model checkpointing. "
                     "If the value is between 0 and 1, validation will be performed multiple times per epoch, "
                     "e.g. 1/8 will result in 8 validations per epoch."
-                ),
+                },
             ),
             UniformFloatHyperparameter(
                 "learning_rate",
                 lower=0.0,
                 upper=1.0,
                 default_value=numeric_defaults["learning_rate"],
-                meta=dict(desc="Learning rate for finetuning."),
+                meta={"desc": "Learning rate for finetuning."},
             ),
             Constant("weight_decay", numeric_defaults["weight_decay"]),
             # report_to: for consistency with text-to-text-lora but wandb and tensorboard are not supported yet
             Constant("report_to", string_defaults["report_to"]),
             Boolean(
                 "use_cpu_offloading",
                 default=False,
-                meta=dict(desc="Whether to use CPU offloading for finetuning."),
+                meta={"desc": "Whether to use CPU offloading for finetuning."},
             ),  # necessary for Flux in float16 on L40S GPU (48gb VRAM)
             CategoricalHyperparameter(
                 "optimizer",
                 choices=["AdamW8bit", "AdamW", "Adam"],
                 default_value=string_defaults["optimizer"],
-                meta=dict(desc="Which optimizer to use for finetuning."),
+                meta={"desc": "Which optimizer to use for finetuning."},
             ),
         ]
 
@@ -540,7 +540,6 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
         """
         lr = self.learning_rate
         wd = self.weight_decay
-        kwargs = {"eps": 1e-7} if self.trainer.precision in [16, "16", "16-true"] else {}
 
         if self.optimizer_name == "AdamW8bit":
             optimizer_cls = AdamW8bit
@@ -553,4 +552,7 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
             optimizer_cls = getattr(torch.optim, self.optimizer_name)
         finetune_params = get_trainable_parameters(self.pipeline)
 
-        return optimizer_cls(finetune_params, lr=lr, weight_decay=wd, **kwargs)
+        if self.trainer.precision in [16, "16", "16-true"]:
+            return optimizer_cls(finetune_params, lr=lr, weight_decay=wd, eps=1e-7)
+        else:
+            return optimizer_cls(finetune_params, lr=lr, weight_decay=wd)

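The optimizer change above replaces a conditionally built `**kwargs` with two explicit call sites. Unpacking an anonymous `dict[str, float]` erases which keywords are actually present, so the checker cannot match them against the optimizer's signature one by one; spelling out both calls keeps every argument statically checkable at the cost of minor duplication. A before/after sketch under that assumption:

import torch

params = [torch.nn.Parameter(torch.zeros(2))]
use_fp16 = True

# Before: the checker only sees `dict[str, float]` behind **kwargs and
# cannot verify the keys against AdamW's keyword parameters.
kwargs = {"eps": 1e-7} if use_fp16 else {}
opt = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.0, **kwargs)

# After: each call passes literal keywords the checker can verify.
if use_fp16:
    opt = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.0, eps=1e-7)
else:
    opt = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.0)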