-
Notifications
You must be signed in to change notification settings - Fork 6.4k
[FEAT] DDUF format #10037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEAT] DDUF format #10037
Changes from 2 commits
1fb86e3
0389333
2eeda25
cbee7cb
d840867
78135f1
7d2c7d5
e66c4d0
b14bffe
1cd5155
b8a43e7
cac988a
977baa3
81bd097
c4df147
d0a861c
1ec988f
5217712
6922226
3b0d84d
1bc953b
04ecf0e
9fff68a
ed6c727
17d50d1
59929a5
aa0d497
a793066
7602952
660d7c8
8358ef6
4e7d15a
cc75db3
63575af
1e5ebf5
53e100b
54991ac
1eb25dc
0cb1b98
5ec3951
1785eaa
7486016
73e81a5
ea0126d
3ebdcff
5943a60
021abf0
af2ca07
9d70b6c
47cb92c
c9734ab
27ebf9e
d5dbb5c
627aec0
f0e21a9
e9b7429
b8b699a
6003176
0e54b06
a026055
03e30b4
da48dcb
0fbea9a
6a163c7
67b617e
b40272e
ce237f3
454b9b9
366aa2f
6648995
21ae7ee
15d4569
f3a4ddc
a032025
faa0cac
0205cc8
9cda4c1
cd0734e
5037d39
c9e08da
02a368b
7bc9347
fff5954
da402da
9ebbf84
aaaa947
290b88d
ac420af
f62527f
c34ce42
c899fd0
bb1cff8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -193,6 +193,8 @@ def save_pretrained( | |
variant: Optional[str] = None, | ||
max_shard_size: Optional[Union[int, str]] = None, | ||
push_to_hub: bool = False, | ||
dduf_format: bool = False, | ||
dduf_filename: Optional[Union[str, os.PathLike]] = None, | ||
SunMarc marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
**kwargs, | ||
): | ||
""" | ||
|
@@ -227,6 +229,9 @@ class implements both a save and loading method. The pipeline is easily reloaded | |
model_index_dict.pop("_module", None) | ||
model_index_dict.pop("_name_or_path", None) | ||
|
||
if dduf_format and dduf_filename is None: | ||
raise RuntimeError("You need set dduf_filename if you want to save your model in DDUF format.") | ||
|
||
if push_to_hub: | ||
commit_message = kwargs.pop("commit_message", None) | ||
private = kwargs.pop("private", False) | ||
|
@@ -301,6 +306,34 @@ def is_saveable_module(name, value): | |
|
||
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) | ||
|
||
if dduf_format: | ||
import shutil | ||
import tarfile | ||
|
||
dduf_file_path = os.path.join(save_directory, dduf_filename) | ||
|
||
if os.path.isdir(dduf_file_path): | ||
logger.warning( | ||
f"Removing the existing folder {dduf_file_path} so that we can save the DDUF archive." | ||
) | ||
SunMarc marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
shutil.rmtree(dduf_file_path) | ||
if ( | ||
os.path.exists(dduf_file_path) | ||
and os.path.isfile(dduf_file_path) | ||
and tarfile.is_tarfile(dduf_file_path) | ||
): | ||
# Open in append mode if the file exists | ||
mode = "a" | ||
SunMarc marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
else: | ||
# Open in write mode to create it if it doesn't exist | ||
mode = "w:" | ||
with tarfile.open(dduf_file_path, mode) as tar: | ||
dir_to_archive = os.path.join(save_directory, pipeline_component_name) | ||
if os.path.isdir(dir_to_archive): | ||
SunMarc marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
tar.add(dir_to_archive, arcname=os.path.basename(dir_to_archive)) | ||
# remove from save_directory after we added it to the archive | ||
shutil.rmtree(dir_to_archive) | ||
SunMarc marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
julien-c marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
# finally save the config | ||
self.save_config(save_directory) | ||
|
||
|
@@ -523,6 +556,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights | ||
saved using | ||
[`~DiffusionPipeline.save_pretrained`]. | ||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf archive or | ||
folder | ||
torch_dtype (`str` or `torch.dtype`, *optional*): | ||
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the | ||
dtype is automatically derived from the model's weights. | ||
|
@@ -617,6 +652,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
variant (`str`, *optional*): | ||
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when | ||
loading `from_flax`. | ||
dduf(`str`, *optional*): | ||
Load weights from the specified dduf archive or folder. | ||
SunMarc marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
<Tip> | ||
|
||
|
@@ -666,6 +703,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
offload_state_dict = kwargs.pop("offload_state_dict", False) | ||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) | ||
variant = kwargs.pop("variant", None) | ||
dduf = kwargs.pop("dduf", None) | ||
use_safetensors = kwargs.pop("use_safetensors", None) | ||
use_onnx = kwargs.pop("use_onnx", None) | ||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) | ||
|
@@ -736,6 +774,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
custom_pipeline=custom_pipeline, | ||
custom_revision=custom_revision, | ||
variant=variant, | ||
dduf=dduf, | ||
load_connected_pipeline=load_connected_pipeline, | ||
**kwargs, | ||
) | ||
|
@@ -762,6 +801,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
# pop out "_ignore_files" as it is only needed for download | ||
config_dict.pop("_ignore_files", None) | ||
|
||
if dduf: | ||
import tarfile | ||
|
||
tar_file_path = os.path.join(cached_folder, dduf) | ||
extract_to = os.path.join(cached_folder, f"{dduf}_extracted") | ||
# if tar file, we need to extract the tarfile and remove it | ||
if os.path.isfile(tar_file_path): | ||
if tarfile.is_tarfile(tar_file_path): | ||
with tarfile.open(tar_file_path, "r") as tar: | ||
tar.extractall(extract_to) | ||
# remove tar archive to free memory | ||
os.remove(tar_file_path) | ||
# rename folder to match the name of the dduf archive | ||
os.rename(extract_to, tar_file_path) | ||
|
||
else: | ||
raise RuntimeError("The dduf path passed is not a tar archive") | ||
# udapte cached folder location as the dduf content is in a seperate folder | ||
cached_folder = tar_file_path | ||
|
||
# 2. Define which model components should load variants | ||
# We retrieve the information by matching whether variant model checkpoints exist in the subfolders. | ||
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors` | ||
|
@@ -1227,6 +1285,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |
variant (`str`, *optional*): | ||
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when | ||
loading `from_flax`. | ||
dduf(`str`, *optional*): | ||
Load weights from the specified DDUF archive or folder. | ||
use_safetensors (`bool`, *optional*, defaults to `None`): | ||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the | ||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors | ||
|
@@ -1267,6 +1327,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |
use_onnx = kwargs.pop("use_onnx", None) | ||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) | ||
trust_remote_code = kwargs.pop("trust_remote_code", False) | ||
dduf = kwargs.pop("dduf", False) | ||
|
||
allow_pickle = False | ||
if use_safetensors is None: | ||
|
@@ -1346,6 +1407,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] | ||
# also allow downloading config.json files with the model | ||
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] | ||
# also allow downloading the dduf | ||
# TODO: check that the file actually exist | ||
if dduf is not None: | ||
allow_patterns += [dduf] | ||
SunMarc marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
allow_patterns += [ | ||
SCHEDULER_CONFIG_NAME, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As mentioned over Slack, why do we need two arguments here? How about just
dduf_filename
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will just use that then !
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that most, if not all, the components of the file name should be inferred. One of the reasons to use DDUF would be to provide confidence about the contents based on filename inspection, so we should try to make it work with something like an optional prefix, but build the rest automatically. So something like
dduf_name="my-lora"
resulting in the library saving to"my-lora-stable-diffusion-3-medium-diffusers-fp16.dduf"
, or something like that.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok to use some kind of convention, but i don't think it should be an absolute requirement, ie. people can rename their files if they want to (but the metadata should be inside the file anyways)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that is the plan.