Commit 4e2c1f3

Add config docs (#429)
* advance * finish * finish
1 parent 5e6417e commit 4e2c1f3

4 files changed: 84 additions and 31 deletions.

docs/source/api/configuration.mdx

Lines changed: 8 additions & 13 deletions
@@ -10,19 +10,14 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
1010
specific language governing permissions and limitations under the License.
1111
-->
1212

13-
# Models
13+
# Configuration
1414

15-
Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
16-
The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
17-
The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
15+
In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`] and models of type [`ModelMixin`] inherit from [`ConfigMixin`], which stores all parameters that are
16+
passed to the respective `__init__` methods in a JSON-configuration file.
1817

19-
## API
18+
TODO(PVP) - add example and better info here
2019

21-
Models should provide the `def forward` function and initialization of the model.
22-
All saving, loading, and utilities should be in the base ['ModelMixin'] class.
23-
24-
## Examples
25-
26-
- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
27-
- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediction in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
28-
- TODO: mention VAE / SDE score estimation
20+
## ConfigMixin
21+
[[autodoc]] ConfigMixin
22+
- from_config
23+
- save_config
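
The mechanism described in the new `configuration.mdx` text (init parameters are stored and written to a JSON configuration file) can be illustrated with a small standalone sketch. This is not the Diffusers implementation: `SketchConfigMixin`, `register`, `ToyScheduler`, and `roundtrip` are hypothetical names invented for this example, and only the standard library is used.

```python
import json
import os
import tempfile


class SketchConfigMixin:
    """Hypothetical stand-in for ConfigMixin; NOT the diffusers implementation."""

    config_name = None  # subclasses set the JSON filename

    def register(self, **kwargs):
        # Keep every init parameter so it can be serialized later.
        self.config = dict(kwargs)

    def save_config(self, save_directory):
        # Mirrors the directory check visible in the diff of save_config.
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, self.config_name), "w", encoding="utf-8") as f:
            json.dump(self.config, f, indent=2, sort_keys=True)

    @classmethod
    def from_config(cls, directory):
        # Rebuild the object by passing the stored parameters back to __init__.
        with open(os.path.join(directory, cls.config_name), encoding="utf-8") as f:
            return cls(**json.load(f))


class ToyScheduler(SketchConfigMixin):
    config_name = "scheduler_config.json"  # assumed filename, for illustration only

    def __init__(self, num_train_timesteps=1000, beta_start=0.0001):
        self.register(num_train_timesteps=num_train_timesteps, beta_start=beta_start)


def roundtrip():
    # Save a config to a temporary directory and load it back.
    with tempfile.TemporaryDirectory() as tmp:
        ToyScheduler(num_train_timesteps=50).save_config(tmp)
        return ToyScheduler.from_config(tmp).config
```

Calling `roundtrip()` returns the parameters that were passed to `__init__`, including defaults, which is the property the documentation attributes to `ConfigMixin`.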

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
99

1010
__version__ = "0.3.0.dev0"
1111

12+
from .configuration_utils import ConfigMixin
1213
from .modeling_utils import ModelMixin
1314
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
1415
from .onnx_utils import OnnxRuntimeModel

src/diffusers/configuration_utils.py

Lines changed: 72 additions & 9 deletions
@@ -37,9 +37,16 @@
3737

3838
class ConfigMixin:
3939
r"""
40-
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
41-
methods for loading/downloading/saving configurations.
42-
40+
Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles all
41+
methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
42+
- [`~ConfigMixin.from_config`]
43+
- [`~ConfigMixin.save_config`]
44+
45+
Class attributes:
46+
- **config_name** (`str`) -- A filename under which the config should be stored when calling
47+
[`~ConfigMixin.save_config`] (should be overridden by the subclass).
48+
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
49+
overridden by the subclass).
4350
"""
4451
config_name = None
4552
ignore_for_config = []
@@ -74,8 +81,6 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
7481
Args:
7582
save_directory (`str` or `os.PathLike`):
7683
Directory where the configuration JSON file will be saved (will be created if it does not exist).
77-
kwargs:
78-
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
7984
"""
8085
if os.path.isfile(save_directory):
8186
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -90,6 +95,64 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
9095

9196
@classmethod
9297
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
98+
r"""
99+
Instantiate a Python class from a pre-defined JSON-file.
100+
101+
Parameters:
102+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
103+
Can be either:
104+
105+
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
106+
organization name, like `google/ddpm-celebahq-256`.
107+
- A path to a *directory* containing a configuration file saved using [`~ConfigMixin.save_config`], e.g.,
108+
`./my_model_directory/`.
109+
110+
cache_dir (`Union[str, os.PathLike]`, *optional*):
111+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
112+
standard cache should not be used.
113+
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
114+
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
115+
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
116+
checkpoint with 3 labels).
117+
force_download (`bool`, *optional*, defaults to `False`):
118+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
119+
cached versions if they exist.
120+
resume_download (`bool`, *optional*, defaults to `False`):
121+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
122+
file exists.
123+
proxies (`Dict[str, str]`, *optional*):
124+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
125+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
126+
output_loading_info (`bool`, *optional*, defaults to `False`):
127+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
128+
local_files_only (`bool`, *optional*, defaults to `False`):
129+
Whether or not to only look at local files (i.e., do not try to download the model).
130+
use_auth_token (`str` or `bool`, *optional*):
131+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
132+
when running `huggingface-cli login` (stored in `~/.huggingface`).
133+
revision (`str`, *optional*, defaults to `"main"`):
134+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
135+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
136+
identifier allowed by git.
137+
mirror (`str`, *optional*):
138+
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
139+
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
140+
Please refer to the mirror site for more information.
141+
142+
<Tip>
143+
144+
Passing `use_auth_token=True` is required when you want to use a private model.
145+
146+
</Tip>
147+
148+
<Tip>
149+
150+
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
151+
use this method in a firewalled environment.
152+
153+
</Tip>
154+
155+
"""
93156
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
94157

95158
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
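
The two calls above load a config dictionary and then split it into arguments for `__init__` and leftover keys. The internals of `extract_init_dict` are not shown in this diff, so the following is only a sketch of one plausible way to do such a split with `inspect.signature`; `split_init_kwargs` and `ToyModel` are hypothetical names and do not exist in Diffusers.

```python
import inspect


def split_init_kwargs(cls, config_dict, **overrides):
    """Hypothetical helper: keep only the keys that cls.__init__ accepts."""
    accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
    # Assumption: explicitly passed keyword arguments win over stored values.
    merged = {**config_dict, **overrides}
    init_dict = {k: v for k, v in merged.items() if k in accepted}
    unused = {k: v for k, v in merged.items() if k not in accepted}
    return init_dict, unused


class ToyModel:
    def __init__(self, sample_size=32, in_channels=3):
        self.sample_size = sample_size
        self.in_channels = in_channels


# "legacy_flag" is not an __init__ parameter, so it ends up in `unused`.
init_dict, unused = split_init_kwargs(
    ToyModel, {"sample_size": 64, "legacy_flag": True}, in_channels=1
)
model = ToyModel(**init_dict)
```

Returning the leftover keys separately is what makes an option like `return_unused_kwargs` possible: the caller can inspect what was ignored instead of having it silently dropped.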
@@ -298,10 +361,10 @@ def __setitem__(self, name, value):
298361

299362

300363
def register_to_config(init):
301-
"""
302-
Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically
303-
sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be
304-
registered in the config, use the `ignore_for_config` class variable
364+
r"""
365+
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
366+
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
367+
shouldn't be registered in the config, use the `ignore_for_config` class variable.
305368
306369
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
307370
"""
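
The behavior described in this docstring (init arguments are recorded automatically, names in `ignore_for_config` are skipped, and private keyword arguments are dropped) can be sketched as a standalone decorator. This is an illustration under those stated assumptions, not the body of `register_to_config`; `register_to_config_sketch` and `ToyScheduler` are hypothetical names.

```python
import functools
import inspect


def register_to_config_sketch(init):
    """Hypothetical decorator, loosely modeled on the docstring above."""

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        # Private keyword arguments (leading underscore) never reach the init.
        public_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        init(self, *args, **public_kwargs)

        # Resolve positional arguments and defaults to parameter names.
        bound = inspect.signature(init).bind(self, *args, **public_kwargs)
        bound.apply_defaults()
        ignore = set(getattr(self, "ignore_for_config", []))
        self.config = {
            name: value
            for name, value in bound.arguments.items()
            if name != "self" and name not in ignore
        }

    return wrapper


class ToyScheduler:
    ignore_for_config = ["tensor_format"]

    @register_to_config_sketch
    def __init__(self, num_train_timesteps=1000, tensor_format="pt"):
        self.tensor_format = tensor_format


# `_internal` is silently discarded; `tensor_format` is used but not recorded.
scheduler = ToyScheduler(num_train_timesteps=10, _internal="dropped")
```

Here `scheduler.config` holds only `num_train_timesteps`, while `tensor_format` is still available as a regular attribute.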

src/diffusers/modeling_utils.py

Lines changed: 3 additions & 9 deletions
@@ -119,8 +119,6 @@ class ModelMixin(torch.nn.Module):
119119
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
120120
and saving models.
121121
122-
Class attributes:
123-
124122
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
125123
[`~modeling_utils.ModelMixin.save_pretrained`].
126124
"""
@@ -200,10 +198,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
200198
Can be either:
201199
202200
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
203-
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
204-
user or organization name, like `dbmdz/bert-base-german-cased`.
205-
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
206-
e.g., `./my_model_directory/`.
201+
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
202+
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
203+
`./my_model_directory/`.
207204
208205
cache_dir (`Union[str, os.PathLike]`, *optional*):
209206
Path to a directory in which a downloaded pretrained model configuration should be cached if the
@@ -236,9 +233,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
236233
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
237234
Please refer to the mirror site for more information.
238235
239-
kwargs (remaining dictionary of keyword arguments, *optional*):
240-
Can be used to update the [`ConfigMixin`] of the model (after it being loaded).
241-
242236
<Tip>
243237
244238
Passing `use_auth_token=True` is required when you want to use a private model.
