Skip to content

Commit 9179302

Browse files
committed
Create training script for single GPU and set
model format to .safetensors
1 parent 3f0bdbb commit 9179302

File tree

1 file changed

+28
-7
lines changed
  • examples/research_projects/ip_adapter

1 file changed

+28
-7
lines changed

examples/research_projects/ip_adapter/README.md

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,25 @@ Certainly! Below is the documentation in pure Markdown format:
5151
The Accelerate launch command is used to train a model using multiple GPUs and mixed precision training. It launches the training script `tutorial_train_ip-adapter.py` with specified parameters and configurations.
5252

5353
#### Usage Example:
54+
55+
```
56+
accelerate launch --mixed_precision "fp16" \
57+
tutorial_train_ip-adapter.py \
58+
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
59+
--image_encoder_path="{image_encoder_path}" \
60+
--data_json_file="{data.json}" \
61+
--data_root_path="{image_path}" \
62+
--mixed_precision="fp16" \
63+
--resolution=512 \
64+
--train_batch_size=8 \
65+
--dataloader_num_workers=4 \
66+
--learning_rate=1e-04 \
67+
--weight_decay=0.01 \
68+
--output_dir="{output_dir}" \
69+
--save_steps=10000
70+
```
71+
72+
### Multi-GPU Script:
5473
```
5574
accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
5675
tutorial_train_ip-adapter.py \
@@ -92,25 +111,27 @@ The provided inference code is used to load a trained model checkpoint and extra
92111

93112
#### Usage Example:
94113
```python
95-
import torch
114+
from safetensors.torch import load_file, save_file
96115

97-
# Load the trained model checkpoint
98-
ckpt = "checkpoint-50000/pytorch_model.bin"
99-
sd = torch.load(ckpt, map_location="cpu")
116+
# Load the trained model checkpoint in safetensors format
117+
ckpt = "checkpoint-50000/pytorch_model.safetensors"
118+
sd = load_file(ckpt) # Using safetensors load function
100119

101120
# Extract image projection and IP adapter components
102121
image_proj_sd = {}
103122
ip_sd = {}
123+
104124
for k in sd:
105125
if k.startswith("unet"):
106-
pass
126+
pass # Skip unet-related keys
107127
elif k.startswith("image_proj_model"):
108128
image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
109129
elif k.startswith("adapter_modules"):
110130
ip_sd[k.replace("adapter_modules.", "")] = sd[k]
111131

112-
# Save the components into a binary file
113-
torch.save({"image_proj": image_proj_sd, "ip_adapter": ip_sd}, "ip_adapter.bin")
132+
# Save the components into separate safetensors files
133+
save_file(image_proj_sd, "image_proj.safetensors")
134+
save_file(ip_sd, "ip_adapter.safetensors")
114135
```
115136

116137
#### Parameters:

0 commit comments

Comments
 (0)