Skip to content

Commit 5c38c0d

Browse files
author
Kyle Butler
committed
Make MPS support additive and preserve CUDA defaults
1 parent 541815d commit 5c38c0d

File tree

11 files changed

+173
-80
lines changed

11 files changed

+173
-80
lines changed

config/examples/train_lora_flux_24gb.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ config:
99
training_folder: "output"
1010
# uncomment to see performance stats in the terminal every N steps
1111
# performance_log_every: 1000
12-
device: mps
12+
device: cuda:0
1313
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
1414
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
1515
# trigger_word: "p3r5on"
@@ -45,7 +45,7 @@ config:
4545
train_text_encoder: false # probably won't work with flux
4646
gradient_checkpointing: true # need this on unless you have a ton of vram
4747
noise_scheduler: "flowmatch" # for training only
48-
optimizer: "adamw" # adamw8bit not supported on mps
48+
optimizer: "adamw8bit"
4949
lr: 1e-4
5050
# uncomment this to skip the pre training sample
5151
# skip_first_sample: true
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
---
2+
job: extension
3+
config:
4+
# this name will be the folder and filename name
5+
name: "my_first_flux_lora_v1"
6+
process:
7+
- type: 'sd_trainer'
8+
# root folder to save training sessions/samples/weights
9+
training_folder: "output"
10+
# uncomment to see performance stats in the terminal every N steps
11+
# performance_log_every: 1000
12+
device: mps
13+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
14+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15+
# trigger_word: "p3r5on"
16+
network:
17+
type: "lora"
18+
linear: 16
19+
linear_alpha: 16
20+
save:
21+
dtype: float16 # precision to save
22+
save_every: 250 # save every this many steps
23+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
24+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
25+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26+
# hf_repo_id: your-username/your-model-slug
27+
# hf_private: true #whether the repo is private or public
28+
datasets:
29+
# datasets are a folder of images. captions need to be txt files with the same name as the image
30+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31+
# images will automatically be resized and bucketed into the resolution specified
32+
# on windows, escape back slashes with another backslash so
33+
# "C:\\path\\to\\images\\folder"
34+
- folder_path: "/path/to/images/folder"
35+
caption_ext: "txt"
36+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37+
shuffle_tokens: false # shuffle caption order, split by commas
38+
cache_latents_to_disk: true # leave this true unless you know what you're doing
39+
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40+
train:
41+
batch_size: 1
42+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
43+
gradient_accumulation_steps: 1
44+
train_unet: true
45+
train_text_encoder: false # probably won't work with flux
46+
gradient_checkpointing: true # need this on unless you have a ton of vram
47+
noise_scheduler: "flowmatch" # for training only
48+
optimizer: "adamw" # adamw8bit not supported on mps
49+
lr: 1e-4
50+
# uncomment this to skip the pre training sample
51+
# skip_first_sample: true
52+
# uncomment to completely disable sampling
53+
# disable_sampling: true
54+
# uncomment to use new bell-curved weighting. Experimental but may produce better results
55+
# linear_timesteps: true
56+
57+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
58+
ema_config:
59+
use_ema: true
60+
ema_decay: 0.99
61+
62+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
63+
dtype: bf16
64+
model:
65+
# huggingface model name or path
66+
name_or_path: "black-forest-labs/FLUX.1-dev"
67+
is_flux: true
68+
quantize: false # 8-bit quantization backends are CUDA-only
69+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70+
sample:
71+
sampler: "flowmatch" # must match train.noise_scheduler
72+
sample_every: 250 # sample every this many steps
73+
width: 1024
74+
height: 1024
75+
prompts:
76+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
77+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78+
- "woman with red hair, playing chess at the park, bomb going off in the background"
79+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82+
- "a bear building a log cabin in the snow covered mountains"
83+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84+
- "hipster man with a beard, building a chair, in a wood shop"
85+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86+
- "a man holding a sign that says, 'this is a sign'"
87+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88+
neg: "" # not used on flux
89+
seed: 42
90+
walk_seed: true
91+
guidance_scale: 4
92+
sample_steps: 20
93+
# you can add any additional meta info here. [name] is replaced with config name at top
94+
meta:
95+
name: "[name]"
96+
version: '1.0'

toolkit/custom_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
3232
from toolkit.prompt_utils import PromptEmbeds
3333
import weakref
34-
from toolkit import device_utils
3534

3635
if TYPE_CHECKING:
3736
from toolkit.stable_diffusion_model import StableDiffusion
@@ -221,7 +220,8 @@ def setup_adapter(self):
221220
elif self.adapter_type == 'llm_adapter':
222221
kwargs = {}
223222
if self.config.quantize_llm:
224-
if device_utils.is_mps_available():
223+
current_device = torch.device(self.device)
224+
if current_device.type == "mps":
225225
print("Warning: BitsAndBytes 4-bit quantization is not supported on MPS. Disabling quantization for LLM adapter.")
226226
self.config.quantize_llm = False
227227
else:

toolkit/device_utils.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,69 @@
1-
import torch
21
import gc
2+
from contextlib import nullcontext
3+
from typing import Optional, Union
4+
5+
import torch
6+
7+
8+
def _as_torch_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
9+
if device is None:
10+
return get_device()
11+
if isinstance(device, torch.device):
12+
return device
13+
return torch.device(device)
14+
315

416
def get_device() -> torch.device:
517
"""
618
Returns the best available device.
7-
Prioritizes MPS on macOS, then CUDA, then CPU.
19+
Prioritizes CUDA, then MPS, then CPU.
820
"""
9-
if torch.backends.mps.is_available():
10-
return torch.device("mps")
11-
elif torch.cuda.is_available():
21+
if torch.cuda.is_available():
1222
return torch.device("cuda")
23+
elif torch.backends.mps.is_available():
24+
return torch.device("mps")
1325
else:
1426
return torch.device("cpu")
1527

28+
1629
def is_mps_available() -> bool:
1730
return torch.backends.mps.is_available()
1831

32+
1933
def is_cuda_available() -> bool:
2034
return torch.cuda.is_available()
2135

22-
def empty_cache():
36+
37+
def empty_cache(device: Optional[Union[str, torch.device]] = None):
2338
"""
24-
Empties the cache for the current device.
39+
Empties the cache for the selected device.
2540
"""
41+
target_device = _as_torch_device(device)
2642
gc.collect()
27-
if is_mps_available():
28-
torch.mps.empty_cache()
29-
elif is_cuda_available():
43+
if target_device.type == "cuda" and is_cuda_available():
3044
torch.cuda.empty_cache()
45+
elif target_device.type == "mps" and is_mps_available():
46+
torch.mps.empty_cache()
3147

32-
def manual_seed(seed: int):
48+
49+
def manual_seed(seed: int, device: Optional[Union[str, torch.device]] = None):
3350
"""
34-
Sets the seed for the current device.
51+
Sets global seed and device-specific seed when supported.
3552
"""
53+
target_device = _as_torch_device(device)
3654
torch.manual_seed(seed)
37-
if is_mps_available():
38-
torch.mps.manual_seed(seed)
39-
elif is_cuda_available():
55+
if target_device.type == "cuda" and is_cuda_available():
4056
torch.cuda.manual_seed(seed)
57+
elif target_device.type == "mps" and is_mps_available():
58+
torch.mps.manual_seed(seed)
4159

42-
def get_device_name() -> str:
43-
if is_mps_available():
44-
return "mps"
45-
elif is_cuda_available():
46-
return "cuda"
47-
else:
48-
return "cpu"
4960

50-
def autocast():
51-
if is_mps_available():
52-
return torch.autocast(device_type="mps")
53-
elif is_cuda_available():
54-
return torch.autocast(device_type="cuda")
55-
else:
56-
# Fallback to cpu or simple context manager
57-
return torch.autocast(device_type="cpu")
61+
def get_device_name(device: Optional[Union[str, torch.device]] = None) -> str:
62+
return _as_torch_device(device).type
63+
64+
65+
def autocast(device: Optional[Union[str, torch.device]] = None):
66+
target_device = _as_torch_device(device)
67+
if target_device.type in {"cuda", "mps", "cpu"}:
68+
return torch.autocast(device_type=target_device.type)
69+
return nullcontext()

toolkit/losses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def forward(self, pred, target):
4343

4444
# Gradient penalty
4545
def get_gradient_penalty(critic, real, fake, device):
46-
with device_utils.autocast():
46+
with device_utils.autocast(device):
4747
real = real.float()
4848
fake = fake.float()
4949
alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()

toolkit/optimizer.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@ def get_optimizer(
6161

6262
optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, decouple=True, **optimizer_params)
6363
elif lower_type.endswith("8bit"):
64-
# Force fallback on MPS as bitsandbytes requires CUDA
6564
from toolkit import device_utils
66-
if device_utils.is_mps_available():
65+
if device_utils.get_device_name() == "mps":
6766
print("Bitsandbytes 8-bit optimizers are not supported on MPS. Falling back to standard optimizer.")
6867
if lower_type == "adam8bit":
6968
return torch.optim.Adam(params, lr=learning_rate, eps=1e-6, **optimizer_params)
@@ -79,33 +78,17 @@ def get_optimizer(
7978
# Fallback for ademamix or unknown - generic AdamW
8079
return torch.optim.AdamW(params, lr=learning_rate, eps=1e-6, **optimizer_params)
8180

82-
try:
83-
import bitsandbytes
84-
if lower_type == "adam8bit":
85-
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
86-
if lower_type == "ademamix8bit":
87-
return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
88-
elif lower_type == "adamw8bit":
89-
return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
90-
elif lower_type == "lion8bit":
91-
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
92-
else:
93-
raise ValueError(f'Unknown optimizer type {optimizer_type}')
94-
except ImportError:
95-
print("Bitsandbytes not found or not supported. Falling back to standard optimizer.")
96-
if lower_type == "adam8bit":
97-
return torch.optim.Adam(params, lr=learning_rate, eps=1e-6, **optimizer_params)
98-
elif lower_type == "adamw8bit":
99-
return torch.optim.AdamW(params, lr=learning_rate, eps=1e-6, **optimizer_params)
100-
elif lower_type == "lion8bit":
101-
try:
102-
from lion_pytorch import Lion
103-
return Lion(params, lr=learning_rate, **optimizer_params)
104-
except ImportError:
105-
raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
106-
else:
107-
# Fallback for ademamix or unknown - generic AdamW
108-
return torch.optim.AdamW(params, lr=learning_rate, eps=1e-6, **optimizer_params)
81+
import bitsandbytes
82+
if lower_type == "adam8bit":
83+
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
84+
if lower_type == "ademamix8bit":
85+
return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
86+
elif lower_type == "adamw8bit":
87+
return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
88+
elif lower_type == "lion8bit":
89+
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
90+
else:
91+
raise ValueError(f'Unknown optimizer type {optimizer_type}')
10992
elif lower_type == 'adam':
11093
optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
11194
elif lower_type == 'adamw':

toolkit/stable_diffusion_model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from toolkit.ip_adapter import IPAdapter
2929
from toolkit.util.vae import load_vae
3030
from toolkit import train_tools
31-
from toolkit import device_utils
3231
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
3332
from toolkit.metadata import get_meta_for_safetensors
3433
from toolkit.models.decorator import Decorator
@@ -479,7 +478,7 @@ def load_model(self):
479478
te_kwargs = {}
480479
# handle quantization of TE
481480
te_is_quantized = False
482-
if device_utils.is_mps_available() and self.model_config.text_encoder_bits in [4, 8]:
481+
if self.device_torch.type == "mps" and self.model_config.text_encoder_bits in [4, 8]:
483482
print_acc("Warning: 4/8-bit quantization is not supported on MPS. Ignoring quantization.")
484483
else:
485484
if self.model_config.text_encoder_bits == 8:
@@ -568,7 +567,7 @@ def load_model(self):
568567
te_kwargs = {}
569568
# handle quantization of TE
570569
te_is_quantized = False
571-
if device_utils.is_mps_available() and self.model_config.text_encoder_bits in [4, 8]:
570+
if self.device_torch.type == "mps" and self.model_config.text_encoder_bits in [4, 8]:
572571
print_acc("Warning: 4/8-bit quantization is not supported on MPS. Ignoring quantization.")
573572
else:
574573
if self.model_config.text_encoder_bits == 8:
@@ -951,7 +950,7 @@ def load_model(self):
951950
te_kwargs = {}
952951
# handle quantization of TE
953952
te_is_quantized = False
954-
if device_utils.is_mps_available() and self.model_config.text_encoder_bits in [4, 8]:
953+
if self.device_torch.type == "mps" and self.model_config.text_encoder_bits in [4, 8]:
955954
print_acc("Warning: 4/8-bit quantization is not supported on MPS. Ignoring quantization.")
956955
else:
957956
if self.model_config.text_encoder_bits == 8:

toolkit/train_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,8 @@ class LearnableSNRGamma:
658658
This is a trainer for learnable snr gamma
659659
It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps
660660
"""
661-
def __init__(self, noise_scheduler: Union['DDPMScheduler'], device=None):
662-
self.device = device if device is not None else device_utils.get_device()
661+
def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
662+
self.device = device
663663
self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
664664
self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=self.device))
665665
self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=self.device))

toolkit/unloader.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import gc
22
import torch
33
from toolkit.basic import flush
4-
from toolkit.device_utils import is_mps_available
54
from typing import TYPE_CHECKING
65

76

@@ -40,6 +39,8 @@ def unload_text_encoder(model: "BaseModel"):
4039
# we need to make it appear as a text encoder module without actually having one so all
4140
# to functions and what not will work.
4241

42+
is_mps = isinstance(model.device_torch, torch.device) and model.device_torch.type == "mps"
43+
4344
if model.text_encoder is not None:
4445
if isinstance(model.text_encoder, list):
4546
text_encoder_list = []
@@ -51,7 +52,7 @@ def unload_text_encoder(model: "BaseModel"):
5152
text_encoder_list.append(te)
5253
# if we are on mps, we don't want to move to cpu because it's unified memory
5354
# and just freeing the reference is enough and faster
54-
if not is_mps_available():
55+
if not is_mps:
5556
pipe.text_encoder.to('cpu')
5657
else:
5758
pipe.text_encoder.to('meta')
@@ -61,18 +62,18 @@ def unload_text_encoder(model: "BaseModel"):
6162
while hasattr(pipe, f"text_encoder_{i}"):
6263
te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
6364
text_encoder_list.append(te)
64-
if is_mps_available():
65+
if is_mps:
6566
getattr(pipe, f"text_encoder_{i}").to('meta')
6667
setattr(pipe, f"text_encoder_{i}", te)
6768
i += 1
6869
model.text_encoder = text_encoder_list
6970
else:
7071
# only has a single text encoder
71-
if is_mps_available():
72+
if is_mps:
7273
model.text_encoder.to('meta')
7374
model.text_encoder = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
7475

75-
if torch.backends.mps.is_available():
76+
if is_mps:
7677
gc.collect()
7778
torch.mps.empty_cache()
7879

0 commit comments

Comments
 (0)