Skip to content

Commit 229c4b3

Browse files
committed
add from_pretrained/save_pretrained for guider
1 parent 0a4819a commit 229c4b3

File tree

1 file changed

+89
-3
lines changed

1 file changed

+89
-3
lines changed

src/diffusers/guiders/guider_utils.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
15+
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union, Optional
1616

1717
import torch
18+
from huggingface_hub.utils import validate_hf_hub_args
19+
from typing_extensions import Self
20+
21+
import os
1822

1923
from ..configuration_utils import ConfigMixin
20-
from ..utils import get_logger
24+
from ..utils import PushToHubMixin, get_logger
25+
2126

2227

2328
if TYPE_CHECKING:
@@ -30,7 +35,7 @@
3035
logger = get_logger(__name__) # pylint: disable=invalid-name
3136

3237

33-
class BaseGuidance(ConfigMixin):
38+
class BaseGuidance(ConfigMixin, PushToHubMixin):
3439
r"""Base class providing the skeleton for implementing guidance techniques."""
3540

3641
config_name = GUIDER_CONFIG_NAME
@@ -198,6 +203,87 @@ def _prepare_batch(
198203
data_batch[cls._identifier_key] = identifier
199204
return BlockState(**data_batch)
200205

206+
@classmethod
207+
@validate_hf_hub_args
208+
def from_pretrained(
209+
cls,
210+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
211+
subfolder: Optional[str] = None,
212+
return_unused_kwargs=False,
213+
**kwargs,
214+
) -> Self:
215+
r"""
216+
Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.
217+
218+
Parameters:
219+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
220+
Can be either:
221+
222+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
223+
the Hub.
224+
- A path to a *directory* (for example `./my_model_directory`) containing the guider
225+
configuration saved with [`~BaseGuidance.save_pretrained`].
226+
subfolder (`str`, *optional*):
227+
The subfolder location of a model file within a larger model repository on the Hub or locally.
228+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
229+
Whether kwargs that are not consumed by the Python class should be returned or not.
230+
cache_dir (`Union[str, os.PathLike]`, *optional*):
231+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
232+
is not used.
233+
force_download (`bool`, *optional*, defaults to `False`):
234+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
235+
cached versions if they exist.
236+
237+
proxies (`Dict[str, str]`, *optional*):
238+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
239+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
240+
output_loading_info(`bool`, *optional*, defaults to `False`):
241+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
242+
local_files_only(`bool`, *optional*, defaults to `False`):
243+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
244+
won't be downloaded from the Hub.
245+
token (`str` or *bool*, *optional*):
246+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
247+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
248+
revision (`str`, *optional*, defaults to `"main"`):
249+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
250+
allowed by Git.
251+
252+
<Tip>
253+
254+
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
255+
`huggingface-cli login`. You can also activate the special
256+
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
257+
firewalled environment.
258+
259+
</Tip>
260+
261+
"""
262+
config, kwargs, commit_hash = cls.load_config(
263+
pretrained_model_name_or_path=pretrained_model_name_or_path,
264+
subfolder=subfolder,
265+
return_unused_kwargs=True,
266+
return_commit_hash=True,
267+
**kwargs,
268+
)
269+
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
270+
271+
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
272+
"""
273+
Save a guider configuration object to a directory so that it can be reloaded using the
274+
[`~BaseGuidance.from_pretrained`] class method.
275+
276+
Args:
277+
save_directory (`str` or `os.PathLike`):
278+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
279+
push_to_hub (`bool`, *optional*, defaults to `False`):
280+
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
281+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
282+
namespace).
283+
kwargs (`Dict[str, Any]`, *optional*):
284+
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
285+
"""
286+
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
201287

202288
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
203289
r"""

0 commit comments

Comments
 (0)