Skip to content

Commit 3beef2d

Browse files
author
Kyle Butler
committed
Add MPS Support
1 parent de7d22c commit 3beef2d

File tree

18 files changed

+465
-65
lines changed

18 files changed

+465
-65
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
---
2+
job: extension
3+
config:
4+
# this name will be the folder and filename name
5+
name: "my_first_flux_lora_v1"
6+
process:
7+
- type: 'sd_trainer'
8+
# root folder to save training sessions/samples/weights
9+
training_folder: "output"
10+
# uncomment to see performance stats in the terminal every N steps
11+
# performance_log_every: 1000
12+
device: mps
13+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
14+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15+
# trigger_word: "p3r5on"
16+
network:
17+
type: "lora"
18+
linear: 16
19+
linear_alpha: 16
20+
save:
21+
dtype: float16 # precision to save
22+
save_every: 250 # save every this many steps
23+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
24+
push_to_hub: false # change this to true to push your trained model to Hugging Face.
25+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26+
# hf_repo_id: your-username/your-model-slug
27+
# hf_private: true #whether the repo is private or public
28+
datasets:
29+
# datasets are a folder of images. captions need to be txt files with the same name as the image
30+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31+
# images will automatically be resized and bucketed into the resolution specified
32+
# on windows, escape back slashes with another backslash so
33+
# "C:\\path\\to\\images\\folder"
34+
- folder_path: "/path/to/images/folder"
35+
caption_ext: "txt"
36+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37+
shuffle_tokens: false # shuffle caption order, split by commas
38+
cache_latents_to_disk: true # leave this true unless you know what you're doing
39+
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40+
train:
41+
batch_size: 1
42+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
43+
gradient_accumulation_steps: 1
44+
train_unet: true
45+
train_text_encoder: false # probably won't work with flux
46+
gradient_checkpointing: true # need this on unless you have a ton of vram
47+
noise_scheduler: "flowmatch" # for training only
48+
optimizer: "adamw" # adamw8bit not supported on mps
49+
lr: 1e-4
50+
# uncomment this to skip the pre training sample
51+
# skip_first_sample: true
52+
# uncomment to completely disable sampling
53+
# disable_sampling: true
54+
# uncomment to use new bell curved weighting. Experimental but may produce better results
55+
# linear_timesteps: true
56+
57+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
58+
ema_config:
59+
use_ema: true
60+
ema_decay: 0.99
61+
62+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
63+
dtype: bf16
64+
model:
65+
# huggingface model name or path
66+
name_or_path: "black-forest-labs/FLUX.1-dev"
67+
is_flux: true
68+
quantize: false # 8-bit quantization backends are CUDA-only
69+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70+
sample:
71+
sampler: "flowmatch" # must match train.noise_scheduler
72+
sample_every: 250 # sample every this many steps
73+
width: 1024
74+
height: 1024
75+
prompts:
76+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
77+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78+
- "woman with red hair, playing chess at the park, bomb going off in the background"
79+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82+
- "a bear building a log cabin in the snow covered mountains"
83+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84+
- "hipster man with a beard, building a chair, in a wood shop"
85+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86+
- "a man holding a sign that says, 'this is a sign'"
87+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88+
neg: "" # not used on flux
89+
seed: 42
90+
walk_seed: true
91+
guidance_scale: 4
92+
sample_steps: 20
93+
# you can add any additional meta info here. [name] is replaced with config name at top
94+
meta:
95+
name: "[name]"
96+
version: '1.0'

flux_train_ui.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from PIL import Image
1111
import torch
1212
import uuid
13+
from toolkit import device_utils
1314
import os
1415
import shutil
1516
import json
@@ -98,7 +99,7 @@ def create_dataset(*inputs):
9899

99100
def run_captioning(images, concept_sentence, *captions):
100101
#Load internally to not consume resources for training
101-
device = "cuda" if torch.cuda.is_available() else "cpu"
102+
device = device_utils.get_device()
102103
torch_dtype = torch.float16
103104
model = AutoModelForCausalLM.from_pretrained(
104105
"multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
@@ -232,11 +233,18 @@ def start_training(
232233

233234
return f"Training completed successfully. Model saved as {slugged_lora_name}"
234235

235-
config_yaml = '''
236-
device: cuda:0
236+
default_device = str(device_utils.get_device())
237+
if default_device == "cuda":
238+
default_device = "cuda:0"
239+
240+
default_quantize = "false" if default_device == "mps" else "true"
241+
default_optimizer = "adamw" if default_device == "mps" else "adamw8bit"
242+
243+
config_yaml = f'''
244+
device: {default_device}
237245
model:
238246
is_flux: true
239-
quantize: true
247+
quantize: {default_quantize}
240248
network:
241249
linear: 16 #it will overcome the 'rank' parameter
242250
linear_alpha: 16 #you can have an alpha different than the ranking if you'd like
@@ -266,7 +274,7 @@ def start_training(
266274
gradient_accumulation_steps: 1
267275
gradient_checkpointing: true
268276
noise_scheduler: flowmatch
269-
optimizer: adamw8bit #options: prodigy, dadaptation, adamw, adamw8bit, lion, lion8bit
277+
optimizer: {default_optimizer} #options: prodigy, dadaptation, adamw, adamw8bit, lion, lion8bit
270278
train_text_encoder: false #probably doesn't work for flux
271279
train_unet: true
272280
'''
@@ -411,4 +419,4 @@ def start_training(
411419
do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
412420

413421
if __name__ == "__main__":
414-
demo.launch(share=True, show_error=True)
422+
demo.launch(share=True, show_error=True)

toolkit/config_modules.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
import torchaudio
88

9+
from toolkit import device_utils
910
from toolkit.prompt_utils import PromptEmbeds
1011

1112
ImgExt = Literal['jpg', 'png', 'webp']
@@ -953,6 +954,11 @@ def __init__(self, **kwargs):
953954

954955
self.num_workers: int = kwargs.get('num_workers', 2)
955956
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
957+
958+
if device_utils.is_mps_available():
959+
# Force num_workers to 0 on MPS to avoid shared memory issues
960+
self.num_workers = 0
961+
self.prefetch_factor = None
956962
self.extra_values: List[float] = kwargs.get('extra_values', [])
957963
self.square_crop: bool = kwargs.get('square_crop', False)
958964
# apply same augmentations to control images. Usually want this true unless special case
@@ -1354,5 +1360,3 @@ def validate_configs(
13541360

13551361
if train_config.diff_output_preservation and train_config.blank_prompt_preservation:
13561362
raise ValueError("Cannot use both differential output preservation and blank prompt preservation at the same time. Please set one of them to False.")
1357-
1358-

toolkit/control_generator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from tqdm import tqdm
99

1010
from torchvision import transforms
11+
from toolkit import device_utils
1112

1213
# supress all warnings
1314
import warnings
@@ -17,7 +18,7 @@
1718

1819

1920
def flush(garbage_collect=True):
20-
torch.cuda.empty_cache()
21+
device_utils.empty_cache()
2122
if garbage_collect:
2223
gc.collect()
2324

@@ -169,8 +170,9 @@ def _generate_control(self, img_path, control_type):
169170
0.229, 0.224, 0.225])
170171
])
171172

173+
# Assuming self.device is correct
172174
input_images = transform_image(img).unsqueeze(
173-
0).to('cuda').to(torch.float16)
175+
0).to(device).to(torch.float16)
174176

175177
# Prediction
176178
preds = self.control_bg_remover(input_images)[-1].sigmoid().cpu()
@@ -259,7 +261,7 @@ def cleanup(self):
259261
for img_path in tqdm(img_list):
260262
for control in controls:
261263
start = time.time()
262-
control_gen = ControlGenerator(torch.device('cuda'))
264+
control_gen = ControlGenerator(device_utils.get_device())
263265
control_gen.debug = args.debug
264266
control_gen.regen = args.regen
265267
control_path = control_gen.get_control_path(img_path, control)

toolkit/custom_adapter.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,18 @@ def setup_adapter(self):
220220
elif self.adapter_type == 'llm_adapter':
221221
kwargs = {}
222222
if self.config.quantize_llm:
223-
bnb_kwargs = {
224-
'load_in_4bit': True,
225-
'bnb_4bit_quant_type': "nf4",
226-
'bnb_4bit_compute_dtype': torch.bfloat16
227-
}
228-
quantization_config = BitsAndBytesConfig(**bnb_kwargs)
229-
kwargs['quantization_config'] = quantization_config
223+
current_device = torch.device(self.device)
224+
if current_device.type == "mps":
225+
print("Warning: BitsAndBytes 4-bit quantization is not supported on MPS. Disabling quantization for LLM adapter.")
226+
self.config.quantize_llm = False
227+
else:
228+
bnb_kwargs = {
229+
'load_in_4bit': True,
230+
'bnb_4bit_quant_type': "nf4",
231+
'bnb_4bit_compute_dtype': torch.bfloat16
232+
}
233+
quantization_config = BitsAndBytesConfig(**bnb_kwargs)
234+
kwargs['quantization_config'] = quantization_config
230235
kwargs['torch_dtype'] = torch_dtype
231236
self.te = AutoModel.from_pretrained(
232237
self.config.text_encoder_path,
@@ -1386,4 +1391,4 @@ def post_weight_update(self):
13861391
# do any kind of updates after the weight update
13871392
if self.config.type == 'vision_direct':
13881393
self.vd_adapter.post_weight_update()
1389-
pass
1394+
pass

toolkit/device_utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import gc
2+
from contextlib import nullcontext
3+
from typing import Optional, Union
4+
5+
import torch
6+
7+
8+
def _as_torch_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
    """Coerce *device* into a ``torch.device``.

    Already-constructed ``torch.device`` objects pass through untouched,
    ``None`` falls back to the best available device, and anything else
    (typically a string such as ``"mps"`` or ``"cuda:0"``) is handed to the
    ``torch.device`` constructor.
    """
    if isinstance(device, torch.device):
        return device
    if device is None:
        return get_device()
    return torch.device(device)
14+
15+
16+
def get_device() -> torch.device:
    """Pick the preferred compute device.

    Preference order is CUDA, then Apple-silicon MPS, then CPU as the
    unconditional fallback.
    """
    for backend, available in (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    ):
        if available():
            return torch.device(backend)
    return torch.device("cpu")
27+
28+
29+
def is_mps_available() -> bool:
    """Report whether the Apple Metal (MPS) backend can be used."""
    available = torch.backends.mps.is_available()
    return available
31+
32+
33+
def is_cuda_available() -> bool:
    """Report whether a CUDA device is usable."""
    available = torch.cuda.is_available()
    return available
35+
36+
37+
def empty_cache(device: Optional[Union[str, torch.device]] = None):
    """Run a GC pass and release the cached allocator memory for *device*.

    With ``device=None`` the best available device is targeted. Backends
    other than CUDA/MPS (and unavailable backends) only get the Python
    garbage-collection pass.
    """
    # Normalize the argument inline (string / torch.device / None all accepted).
    if device is None:
        target = get_device()
    elif isinstance(device, torch.device):
        target = device
    else:
        target = torch.device(device)
    gc.collect()
    if target.type == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif target.type == "mps" and torch.backends.mps.is_available():
        torch.mps.empty_cache()
47+
48+
49+
def manual_seed(seed: int, device: Optional[Union[str, torch.device]] = None):
    """Seed torch's global RNG, plus the backend RNG for *device* when supported.

    ``device=None`` resolves to the best available device; only CUDA and MPS
    have an extra backend-specific seeding step.
    """
    # Normalize the argument inline (string / torch.device / None all accepted).
    if device is None:
        target = get_device()
    elif isinstance(device, torch.device):
        target = device
    else:
        target = torch.device(device)
    torch.manual_seed(seed)
    if target.type == "cuda" and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    elif target.type == "mps" and torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
59+
60+
61+
def get_device_name(device: Optional[Union[str, torch.device]] = None) -> str:
    """Return the backend type string ("cuda", "mps", or "cpu") for *device*.

    Note this is the device *type*, so an index such as ``"cuda:1"``
    collapses to ``"cuda"``.
    """
    if device is None:
        device = get_device()
    elif not isinstance(device, torch.device):
        device = torch.device(device)
    return device.type
63+
64+
65+
def autocast(device: Optional[Union[str, torch.device]] = None):
    """Return an autocast context for *device*, or a no-op context.

    CUDA, MPS, and CPU get a real ``torch.autocast`` context; any other
    backend gets ``nullcontext`` so callers can use ``with`` unconditionally.
    """
    if device is None:
        target = get_device()
    elif isinstance(device, torch.device):
        target = device
    else:
        target = torch.device(device)
    if target.type not in ("cuda", "mps", "cpu"):
        return nullcontext()
    return torch.autocast(device_type=target.type)

toolkit/losses.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from .llvae import LosslessLatentEncoder
3+
from toolkit import device_utils
34

45

56
def total_variation(image):
@@ -42,7 +43,7 @@ def forward(self, pred, target):
4243

4344
# Gradient penalty
4445
def get_gradient_penalty(critic, real, fake, device):
45-
with torch.autocast(device_type='cuda'):
46+
with device_utils.autocast(device):
4647
real = real.float()
4748
fake = fake.float()
4849
alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()
@@ -109,5 +110,3 @@ def separated_chan_loss(latent_chan):
109110
g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target))
110111
b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target))
111112
return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333
112-
113-

0 commit comments

Comments
 (0)