Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions config/examples/train_lora_flux_mps.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
---
job: extension
config:
# this name will be the folder and filename name
name: "my_first_flux_lora_v1"
process:
- type: 'sd_trainer'
# root folder to save training sessions/samples/weights
training_folder: "output"
# uncomment to see performance stats in the terminal every N steps
# performance_log_every: 1000
device: mps
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
# trigger_word: "p3r5on"
network:
type: "lora"
linear: 16
linear_alpha: 16
save:
dtype: float16 # precision to save
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
      push_to_hub: false # change this to true to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
# images will automatically be resized and bucketed into the resolution specified
# on windows, escape back slashes with another backslash so
# "C:\\path\\to\\images\\folder"
- folder_path: "/path/to/images/folder"
caption_ext: "txt"
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
shuffle_tokens: false # shuffle caption order, split by commas
cache_latents_to_disk: true # leave this true unless you know what you're doing
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
train:
batch_size: 1
steps: 2000 # total number of steps to train 500 - 4000 is a good range
gradient_accumulation_steps: 1
train_unet: true
train_text_encoder: false # probably won't work with flux
        gradient_checkpointing: true # need this on unless you have a ton of vram
noise_scheduler: "flowmatch" # for training only
optimizer: "adamw" # adamw8bit not supported on mps
lr: 1e-4
# uncomment this to skip the pre training sample
# skip_first_sample: true
# uncomment to completely disable sampling
# disable_sampling: true
        # uncomment to use new bell curved weighting. Experimental but may produce better results
# linear_timesteps: true

# ema will smooth out learning, but could slow it down. Recommended to leave on.
ema_config:
use_ema: true
ema_decay: 0.99

# will probably need this if gpu supports it for flux, other dtypes may not work correctly
dtype: bf16
model:
# huggingface model name or path
name_or_path: "black-forest-labs/FLUX.1-dev"
is_flux: true
quantize: false # 8-bit quantization backends are CUDA-only
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
sample:
sampler: "flowmatch" # must match train.noise_scheduler
sample_every: 250 # sample every this many steps
width: 1024
height: 1024
prompts:
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
        # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
- "woman with red hair, playing chess at the park, bomb going off in the background"
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
- "a bear building a log cabin in the snow covered mountains"
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
- "hipster man with a beard, building a chair, in a wood shop"
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
- "a man holding a sign that says, 'this is a sign'"
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
neg: "" # not used on flux
seed: 42
walk_seed: true
guidance_scale: 4
sample_steps: 20
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
name: "[name]"
version: '1.0'
20 changes: 14 additions & 6 deletions flux_train_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from PIL import Image
import torch
import uuid
from toolkit import device_utils
import os
import shutil
import json
Expand Down Expand Up @@ -98,7 +99,7 @@ def create_dataset(*inputs):

def run_captioning(images, concept_sentence, *captions):
#Load internally to not consume resources for training
device = "cuda" if torch.cuda.is_available() else "cpu"
device = device_utils.get_device()
torch_dtype = torch.float16
model = AutoModelForCausalLM.from_pretrained(
"multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
Expand Down Expand Up @@ -232,11 +233,18 @@ def start_training(

return f"Training completed successfully. Model saved as {slugged_lora_name}"

config_yaml = '''
device: cuda:0
default_device = str(device_utils.get_device())
if default_device == "cuda":
default_device = "cuda:0"

default_quantize = "false" if default_device == "mps" else "true"
default_optimizer = "adamw" if default_device == "mps" else "adamw8bit"

config_yaml = f'''
device: {default_device}
model:
is_flux: true
quantize: true
quantize: {default_quantize}
network:
linear: 16 #it will overcome the 'rank' parameter
linear_alpha: 16 #you can have an alpha different than the ranking if you'd like
Expand Down Expand Up @@ -266,7 +274,7 @@ def start_training(
gradient_accumulation_steps: 1
gradient_checkpointing: true
noise_scheduler: flowmatch
optimizer: adamw8bit #options: prodigy, dadaptation, adamw, adamw8bit, lion, lion8bit
optimizer: {default_optimizer} #options: prodigy, dadaptation, adamw, adamw8bit, lion, lion8bit
train_text_encoder: false #probably doesn't work for flux
train_unet: true
'''
Expand Down Expand Up @@ -411,4 +419,4 @@ def start_training(
do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)

if __name__ == "__main__":
demo.launch(share=True, show_error=True)
demo.launch(share=True, show_error=True)
8 changes: 6 additions & 2 deletions toolkit/config_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import torchaudio

from toolkit import device_utils
from toolkit.prompt_utils import PromptEmbeds

ImgExt = Literal['jpg', 'png', 'webp']
Expand Down Expand Up @@ -953,6 +954,11 @@ def __init__(self, **kwargs):

self.num_workers: int = kwargs.get('num_workers', 2)
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)

if device_utils.is_mps_available():
# Force num_workers to 0 on MPS to avoid shared memory issues
self.num_workers = 0
self.prefetch_factor = None
self.extra_values: List[float] = kwargs.get('extra_values', [])
self.square_crop: bool = kwargs.get('square_crop', False)
# apply same augmentations to control images. Usually want this true unless special case
Expand Down Expand Up @@ -1354,5 +1360,3 @@ def validate_configs(

if train_config.diff_output_preservation and train_config.blank_prompt_preservation:
raise ValueError("Cannot use both differential output preservation and blank prompt preservation at the same time. Please set one of them to False.")


8 changes: 5 additions & 3 deletions toolkit/control_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tqdm import tqdm

from torchvision import transforms
from toolkit import device_utils

# supress all warnings
import warnings
Expand All @@ -17,7 +18,7 @@


def flush(garbage_collect=True):
    # Release cached memory on whichever accelerator backend is active,
    # then optionally run Python's garbage collector as well.
    device_utils.empty_cache()
    if not garbage_collect:
        return
    gc.collect()

Expand Down Expand Up @@ -169,8 +170,9 @@ def _generate_control(self, img_path, control_type):
0.229, 0.224, 0.225])
])

# Assuming self.device is correct
input_images = transform_image(img).unsqueeze(
0).to('cuda').to(torch.float16)
0).to(device).to(torch.float16)

# Prediction
preds = self.control_bg_remover(input_images)[-1].sigmoid().cpu()
Expand Down Expand Up @@ -259,7 +261,7 @@ def cleanup(self):
for img_path in tqdm(img_list):
for control in controls:
start = time.time()
control_gen = ControlGenerator(torch.device('cuda'))
control_gen = ControlGenerator(device_utils.get_device())
control_gen.debug = args.debug
control_gen.regen = args.regen
control_path = control_gen.get_control_path(img_path, control)
Expand Down
21 changes: 13 additions & 8 deletions toolkit/custom_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,18 @@ def setup_adapter(self):
elif self.adapter_type == 'llm_adapter':
kwargs = {}
if self.config.quantize_llm:
bnb_kwargs = {
'load_in_4bit': True,
'bnb_4bit_quant_type': "nf4",
'bnb_4bit_compute_dtype': torch.bfloat16
}
quantization_config = BitsAndBytesConfig(**bnb_kwargs)
kwargs['quantization_config'] = quantization_config
current_device = torch.device(self.device)
if current_device.type == "mps":
print("Warning: BitsAndBytes 4-bit quantization is not supported on MPS. Disabling quantization for LLM adapter.")
self.config.quantize_llm = False
else:
bnb_kwargs = {
'load_in_4bit': True,
'bnb_4bit_quant_type': "nf4",
'bnb_4bit_compute_dtype': torch.bfloat16
}
quantization_config = BitsAndBytesConfig(**bnb_kwargs)
kwargs['quantization_config'] = quantization_config
kwargs['torch_dtype'] = torch_dtype
self.te = AutoModel.from_pretrained(
self.config.text_encoder_path,
Expand Down Expand Up @@ -1386,4 +1391,4 @@ def post_weight_update(self):
# do any kind of updates after the weight update
if self.config.type == 'vision_direct':
self.vd_adapter.post_weight_update()
pass
pass
69 changes: 69 additions & 0 deletions toolkit/device_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import gc
from contextlib import nullcontext
from typing import Optional, Union

import torch


def _as_torch_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
if device is None:
return get_device()
if isinstance(device, torch.device):
return device
return torch.device(device)


def get_device() -> torch.device:
    """Select the best available compute backend.

    Preference order: CUDA first, then Apple MPS, then CPU.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)


def is_mps_available() -> bool:
    """Report whether the Apple Metal (MPS) backend can be used."""
    available = torch.backends.mps.is_available()
    return available


def is_cuda_available() -> bool:
    """Report whether a CUDA-capable GPU is usable by torch."""
    available = torch.cuda.is_available()
    return available


def empty_cache(device: Optional[Union[str, torch.device]] = None):
    """Collect Python garbage, then release the cached allocator memory
    held for the resolved device.

    No-op beyond ``gc.collect()`` for CPU, which has no allocator cache.
    ``device=None`` targets the best available backend.
    """
    target = torch.device(device) if device is not None else get_device()
    gc.collect()
    if target.type == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif target.type == "mps" and torch.backends.mps.is_available():
        torch.mps.empty_cache()


def manual_seed(seed: int, device: Optional[Union[str, torch.device]] = None):
    """Seed torch's global RNG plus the backend-specific RNG when supported.

    ``device=None`` targets the best available backend.
    """
    target = torch.device(device) if device is not None else get_device()
    torch.manual_seed(seed)
    if target.type == "cuda" and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    elif target.type == "mps" and torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)


def get_device_name(device: Optional[Union[str, torch.device]] = None) -> str:
    """Return the backend type string ('cuda', 'mps', or 'cpu') for *device*.

    ``device=None`` resolves to the best available backend.
    """
    if device is None:
        device = get_device()
    elif not isinstance(device, torch.device):
        device = torch.device(device)
    return device.type


def autocast(device: Optional[Union[str, torch.device]] = None):
    """Return a ``torch.autocast`` context for supported backends, otherwise a no-op.

    ``device=None`` resolves to the best available backend.
    """
    if device is None:
        device = get_device()
    dev_type = device.type if isinstance(device, torch.device) else torch.device(device).type
    if dev_type not in {"cuda", "mps", "cpu"}:
        return nullcontext()
    return torch.autocast(device_type=dev_type)
5 changes: 2 additions & 3 deletions toolkit/losses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from .llvae import LosslessLatentEncoder
from toolkit import device_utils


def total_variation(image):
Expand Down Expand Up @@ -42,7 +43,7 @@ def forward(self, pred, target):

# Gradient penalty
def get_gradient_penalty(critic, real, fake, device):
with torch.autocast(device_type='cuda'):
with device_utils.autocast(device):
real = real.float()
fake = fake.float()
alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()
Expand Down Expand Up @@ -109,5 +110,3 @@ def separated_chan_loss(latent_chan):
g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target))
b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target))
return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333


Loading