|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union |
| 15 | +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union, Optional |
16 | 16 |
|
17 | 17 | import torch |
| 18 | +from huggingface_hub.utils import validate_hf_hub_args |
| 19 | +from typing_extensions import Self |
| 20 | + |
| 21 | +import os |
18 | 22 |
|
19 | 23 | from ..configuration_utils import ConfigMixin |
20 | | -from ..utils import get_logger |
| 24 | +from ..utils import PushToHubMixin, get_logger |
| 25 | + |
21 | 26 |
|
22 | 27 |
|
23 | 28 | if TYPE_CHECKING: |
|
30 | 35 | logger = get_logger(__name__) # pylint: disable=invalid-name |
31 | 36 |
|
32 | 37 |
|
33 | | -class BaseGuidance(ConfigMixin): |
| 38 | +class BaseGuidance(ConfigMixin, PushToHubMixin): |
34 | 39 | r"""Base class providing the skeleton for implementing guidance techniques.""" |
35 | 40 |
|
36 | 41 | config_name = GUIDER_CONFIG_NAME |
@@ -198,6 +203,87 @@ def _prepare_batch( |
198 | 203 | data_batch[cls._identifier_key] = identifier |
199 | 204 | return BlockState(**data_batch) |
200 | 205 |
|
| 206 | + @classmethod |
| 207 | + @validate_hf_hub_args |
| 208 | + def from_pretrained( |
| 209 | + cls, |
| 210 | + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, |
| 211 | + subfolder: Optional[str] = None, |
| 212 | + return_unused_kwargs=False, |
| 213 | + **kwargs, |
| 214 | + ) -> Self: |
| 215 | + r""" |
| 216 | + Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository. |
| 217 | +
|
| 218 | + Parameters: |
| 219 | + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): |
| 220 | + Can be either: |
| 221 | +
|
| 222 | + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on |
| 223 | + the Hub. |
| 224 | + - A path to a *directory* (for example `./my_model_directory`) containing the guider |
| 225 | + configuration saved with [`~BaseGuidance.save_pretrained`]. |
| 226 | + subfolder (`str`, *optional*): |
| 227 | + The subfolder location of a model file within a larger model repository on the Hub or locally. |
| 228 | + return_unused_kwargs (`bool`, *optional*, defaults to `False`): |
| 229 | + Whether kwargs that are not consumed by the Python class should be returned or not. |
| 230 | + cache_dir (`Union[str, os.PathLike]`, *optional*): |
| 231 | + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache |
| 232 | + is not used. |
| 233 | + force_download (`bool`, *optional*, defaults to `False`): |
| 234 | + Whether or not to force the (re-)download of the model weights and configuration files, overriding the |
| 235 | + cached versions if they exist. |
| 236 | +
|
| 237 | + proxies (`Dict[str, str]`, *optional*): |
| 238 | + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', |
| 239 | + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. |
| 240 | + output_loading_info(`bool`, *optional*, defaults to `False`): |
| 241 | + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. |
| 242 | + local_files_only(`bool`, *optional*, defaults to `False`): |
| 243 | + Whether to only load local model weights and configuration files or not. If set to `True`, the model |
| 244 | + won't be downloaded from the Hub. |
| 245 | + token (`str` or *bool*, *optional*): |
| 246 | + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from |
| 247 | + `diffusers-cli login` (stored in `~/.huggingface`) is used. |
| 248 | + revision (`str`, *optional*, defaults to `"main"`): |
| 249 | + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier |
| 250 | + allowed by Git. |
| 251 | +
|
| 252 | + <Tip> |
| 253 | +
|
| 254 | + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with |
| 255 | + `huggingface-cli login`. You can also activate the special |
| 256 | + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a |
| 257 | + firewalled environment. |
| 258 | +
|
| 259 | + </Tip> |
| 260 | +
|
| 261 | + """ |
| 262 | + config, kwargs, commit_hash = cls.load_config( |
| 263 | + pretrained_model_name_or_path=pretrained_model_name_or_path, |
| 264 | + subfolder=subfolder, |
| 265 | + return_unused_kwargs=True, |
| 266 | + return_commit_hash=True, |
| 267 | + **kwargs, |
| 268 | + ) |
| 269 | + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) |
| 270 | + |
| 271 | + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): |
| 272 | + """ |
| 273 | + Save a guider configuration object to a directory so that it can be reloaded using the |
| 274 | + [`~BaseGuidance.from_pretrained`] class method. |
| 275 | +
|
| 276 | + Args: |
| 277 | + save_directory (`str` or `os.PathLike`): |
| 278 | + Directory where the configuration JSON file will be saved (will be created if it does not exist). |
| 279 | + push_to_hub (`bool`, *optional*, defaults to `False`): |
| 280 | + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the |
| 281 | + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your |
| 282 | + namespace). |
| 283 | + kwargs (`Dict[str, Any]`, *optional*): |
| 284 | + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. |
| 285 | + """ |
| 286 | + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) |
201 | 287 |
|
202 | 288 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): |
203 | 289 | r""" |
|
0 commit comments