Skip to content

Commit 8faa822

Browse files
Allow to set config params directly in init (#1419)
* fix * fix deprecated kwargs logic * add tests * finish
1 parent 86aa747 commit 8faa822

13 files changed

+103
-11
lines changed

src/diffusers/configuration_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,18 @@ class ConfigMixin:
8080
- **config_name** (`str`) -- A filename under which the config should be stored when calling
8181
[`~ConfigMixin.save_config`] (should be overridden by subclass).
8282
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
83-
overridden by parent class).
84-
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
85-
class).
83+
overridden by subclass).
84+
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
85+
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
86+
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
87+
subclass).
8688
"""
8789
config_name = None
8890
ignore_for_config = []
8991
has_compatibles = False
9092

93+
_deprecated_kwargs = []
94+
9195
def register_to_config(self, **kwargs):
9296
if self.config_name is None:
9397
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
@@ -195,10 +199,10 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
195199
if "dtype" in unused_kwargs:
196200
init_dict["dtype"] = unused_kwargs.pop("dtype")
197201

198-
if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
199-
deprecate("remove this", "0.10.0", "remove")
200-
predict_epsilon = unused_kwargs.pop("predict_epsilon")
201-
init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample"
202+
# add possible deprecated kwargs
203+
for deprecated_kwarg in cls._deprecated_kwargs:
204+
if deprecated_kwarg in unused_kwargs:
205+
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
202206

203207
# Return model and optionally state and/or unused_kwargs
204208
model = cls(**init_dict)
@@ -526,7 +530,6 @@ def inner_init(self, *args, **kwargs):
526530
# Ignore private kwargs in the init.
527531
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
528532
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
529-
init(self, *args, **init_kwargs)
530533
if not isinstance(self, ConfigMixin):
531534
raise RuntimeError(
532535
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
@@ -553,6 +556,7 @@ def inner_init(self, *args, **kwargs):
553556
)
554557
new_kwargs = {**config_init_kwargs, **new_kwargs}
555558
getattr(self, "register_to_config")(**new_kwargs)
559+
init(self, *args, **init_kwargs)
556560

557561
return inner_init
558562

src/diffusers/models/unet_2d_blocks.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ def __init__(
254254
attn_num_head_channels=1,
255255
attention_type="default",
256256
output_scale_factor=1.0,
257-
**kwargs,
258257
):
259258
super().__init__()
260259

@@ -336,7 +335,6 @@ def __init__(
336335
cross_attention_dim=1280,
337336
dual_cross_attention=False,
338337
use_linear_projection=False,
339-
**kwargs,
340338
):
341339
super().__init__()
342340

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,6 @@ def __init__(
10391039
cross_attention_dim=1280,
10401040
dual_cross_attention=False,
10411041
use_linear_projection=False,
1042-
**kwargs,
10431042
):
10441043
super().__init__()
10451044

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
113113
"""
114114

115115
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
116+
_deprecated_kwargs = ["predict_epsilon"]
116117

117118
@register_to_config
118119
def __init__(

src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
116116
"""
117117

118118
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
119+
_deprecated_kwargs = ["predict_epsilon"]
119120

120121
@property
121122
def has_state(self):

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
105105
"""
106106

107107
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
108+
_deprecated_kwargs = ["predict_epsilon"]
108109

109110
@register_to_config
110111
def __init__(

src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
109109
"""
110110

111111
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
112+
_deprecated_kwargs = ["predict_epsilon"]
112113

113114
@property
114115
def has_state(self):

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
117117
"""
118118

119119
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
120+
_deprecated_kwargs = ["predict_epsilon"]
120121

121122
@register_to_config
122123
def __init__(

src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
149149
"""
150150

151151
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
152+
_deprecated_kwargs = ["predict_epsilon"]
152153

153154
@property
154155
def has_state(self):

tests/test_modeling_common.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,23 @@ def test_enable_disable_gradient_checkpointing(self):
265265
# check disable works
266266
model.disable_gradient_checkpointing()
267267
self.assertFalse(model.is_gradient_checkpointing)
268+
269+
def test_deprecated_kwargs(self):
270+
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
271+
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
272+
273+
if has_kwarg_in_model_class and not has_deprecated_kwarg:
274+
raise ValueError(
275+
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
276+
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
277+
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
278+
" [<deprecated_argument>]`"
279+
)
280+
281+
if not has_kwarg_in_model_class and has_deprecated_kwarg:
282+
raise ValueError(
283+
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
284+
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
285+
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
286+
" from `_deprecated_kwargs = [<deprecated_argument>]`"
287+
)

0 commit comments

Comments
 (0)