
Commit b951454

ksikiric, jfacevedo-google, and entrpn authored
Training added on top of flux_impl (#147)
* Added training code; loss and results are stable
* Rebased on flux_lora and aligned flux_pipeline with changes in generate_flux.py
* Batch text encoding
* Comment out post_training_steps
* Refactor some code for similarity to SD trainers
* Added orbax saving and a new file for inference that utilizes the pipeline
* Update generate_flux_pipeline.py
* Fixed comments and rebased on main
* ruff + code_style

Co-authored-by: Juan Acevedo <[email protected]>
Co-authored-by: Juan Acevedo <[email protected]>
1 parent da779ea commit b951454

File tree

12 files changed: +1430 -28 lines changed

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 17 additions & 12 deletions
@@ -33,6 +33,7 @@
 
 STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
 STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
+FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
 
 
 def create_orbax_checkpoint_manager(
@@ -56,17 +57,20 @@ def create_orbax_checkpoint_manager(
   max_logging.log(f"checkpoint dir: {checkpoint_dir}")
   p = epath.Path(checkpoint_dir)
 
-  item_names = (
-      "unet_config",
-      "vae_config",
-      "text_encoder_config",
-      "scheduler_config",
-      "unet_state",
-      "vae_state",
-      "text_encoder_state",
-      "tokenizer_config",
-  )
-  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
+  if checkpoint_type == FLUX_CHECKPOINT:
+    item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
+  else:
+    item_names = (
+        "unet_config",
+        "vae_config",
+        "text_encoder_config",
+        "scheduler_config",
+        "unet_state",
+        "vae_state",
+        "text_encoder_state",
+        "tokenizer_config",
+    )
+  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
     item_names += (
         "text_encoder_2_state",
         "text_encoder_2_config",
@@ -117,7 +121,7 @@ def load_stable_diffusion_configs(
       "tokenizer_config": orbax.checkpoint.args.JsonRestore(),
   }
 
-  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
+  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
     restore_args["text_encoder_2_config"] = orbax.checkpoint.args.JsonRestore()
 
   return (checkpoint_manager.restore(step, args=orbax.checkpoint.args.Composite(**restore_args)), None)
@@ -139,6 +143,7 @@ def load_params_from_path(
 
   ckpt_path = os.path.join(config.checkpoint_dir, str(step), checkpoint_item)
   ckpt_path = epath.Path(ckpt_path)
+  ckpt_path = os.path.abspath(ckpt_path)
 
   restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
   restored = ckptr.restore(
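For orientation only (not part of the diff): the new FLUX_CHECKPOINT branch is selected by passing the new constant when building the Orbax checkpoint manager. A minimal sketch that mirrors the call FluxCheckpointer makes in its __init__ (shown in the new file below); the checkpoint directory value here is a hypothetical placeholder:

# Sketch only: exercises the new FLUX_CHECKPOINT branch in create_orbax_checkpoint_manager.
from maxdiffusion.checkpointing.checkpointing_utils import (
    FLUX_CHECKPOINT,
    create_orbax_checkpoint_manager,
)

checkpoint_manager = create_orbax_checkpoint_manager(
    "/tmp/flux-checkpoints",  # hypothetical checkpoint_dir
    enable_checkpointing=True,
    save_interval_steps=1,
    checkpoint_type=FLUX_CHECKPOINT,  # selects the ("flux_state", "flux_config", "vae_state", ...) item names
    dataset_type="grain",
)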
src/maxdiffusion/checkpointing/flux_checkpointer.py (new file)

Lines changed: 305 additions & 0 deletions
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC
from contextlib import nullcontext
import functools
import json
import jax
from jax.sharding import Mesh
import orbax.checkpoint as ocp
import grain.python as grain
from maxdiffusion import (
    max_utils,
    FlaxAutoencoderKL,
    max_logging,
)
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from ..pipelines.flux.flux_pipeline import FluxPipeline

from transformers import (CLIPTokenizer, FlaxCLIPTextModel, FlaxT5EncoderModel, AutoTokenizer)

from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
from maxdiffusion.models.flux.util import load_flow_model

FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
_CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX"

FLUX_STATE_KEY = "flux_state"
FLUX_TRANSFORMER_PARAMS_KEY = "flux_transformer_params"
FLUX_STATE_SHARDINGS_KEY = "flux_state_shardings"
FLUX_VAE_PARAMS_KEY = "flux_vae"
VAE_STATE_KEY = "vae_state"
VAE_STATE_SHARDINGS_KEY = "vae_state_shardings"


class FluxCheckpointer(ABC):

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type
    self.checkpoint_format = None

    self.rng = jax.random.PRNGKey(self.config.seed)
    self.devices_array = max_utils.create_device_mesh(config)
    self.mesh = Mesh(self.devices_array, self.config.mesh_axes)
    self.total_train_batch_size = self.config.total_train_batch_size

    self.checkpoint_manager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, config, learning_rate):

    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return tx, learning_rate_scheduler

  def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training):
    transformer = pipeline.flux

    tx, learning_rate_scheduler = None, None
    if is_training:
      learning_rate = self.config.learning_rate

      tx, learning_rate_scheduler = self._create_optimizer(self.config, learning_rate)

    transformer_eval_params = transformer.init_weights(
        rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
    )

    transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

    weights_init_fn = functools.partial(
        pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length
    )
    flux_state, state_mesh_shardings = max_utils.setup_initial_state(
        model=pipeline.flux,
        tx=tx,
        config=self.config,
        mesh=self.mesh,
        weights_init_fn=weights_init_fn,
        model_params=None,
        checkpoint_manager=self.checkpoint_manager,
        checkpoint_item=checkpoint_item_name,
        training=is_training,
    )
    if not self.config.train_new_flux:
      flux_state = flux_state.replace(params=transformer_params)
    flux_state = jax.device_put(flux_state, state_mesh_shardings)
    return flux_state, state_mesh_shardings, learning_rate_scheduler

  def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):

    # Currently VAE training is not supported.
    weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng)
    return max_utils.setup_initial_state(
        model=pipeline.vae,
        tx=None,
        config=self.config,
        mesh=self.mesh,
        weights_init_fn=weights_init_fn,
        model_params=params.get("flux_vae", None),
        checkpoint_manager=self.checkpoint_manager,
        checkpoint_item=checkpoint_item_name,
        training=is_training,
    )

  def restore_data_iterator_state(self, data_iterator):
    if (
        self.config.dataset_type == "grain"
        and data_iterator is not None
        and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
    ):
      max_logging.log("Restoring data iterator from checkpoint")
      restored = self.checkpoint_manager.restore(
          self.checkpoint_manager.latest_step(),
          args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
      )
      data_iterator.local_iterator = restored["iter"]
    else:
      max_logging.log("data iterator checkpoint not found")
    return data_iterator

  def _get_pipeline_class(self):
    return FluxPipeline

  def _set_checkpoint_format(self, checkpoint_format):
    self.checkpoint_format = checkpoint_format

  def save_checkpoint(self, train_step, pipeline, train_states):
    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

    items = {
        "flux_config": ocp.args.JsonSave(config_to_json(pipeline.flux)),
        "vae_config": ocp.args.JsonSave(config_to_json(pipeline.vae)),
        "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)),
    }

    items[FLUX_STATE_KEY] = ocp.args.PyTreeSave(train_states[FLUX_STATE_KEY])
    items["vae_state"] = ocp.args.PyTreeSave(train_states["vae_state"])
    items["scheduler"] = ocp.args.PyTreeSave(train_states["scheduler"])

    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

  def load_params(self, step=None):

    self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX

  def load_flux_configs_from_orbax(self, step):
    max_logging.log("Restoring stable diffusion configs")
    if step is None:
      step = self.checkpoint_manager.latest_step()
      if step is None:
        return None

    restore_args = {
        "flux_config": ocp.args.JsonRestore(),
        "vae_config": ocp.args.JsonRestore(),
        "scheduler_config": ocp.args.JsonRestore(),
    }

    return (self.checkpoint_manager.restore(step, args=ocp.args.Composite(**restore_args)), None)

  def load_diffusers_checkpoint(self):
    flash_block_sizes = max_utils.get_flash_block_sizes(self.config)

    if jax.device_count() == jax.local_device_count():
      context = jax.default_device(jax.devices("cpu")[0])
    else:
      context = nullcontext()

    with context:
      clip_encoder = FlaxCLIPTextModel.from_pretrained(self.config.clip_model_name_or_path, dtype=self.config.weights_dtype)
      clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True)
      t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
      t5_tokenizer = AutoTokenizer.from_pretrained(
          self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
      )

      vae, vae_params = FlaxAutoencoderKL.from_pretrained(
          self.config.pretrained_model_name_or_path,
          subfolder="vae",
          from_pt=True,
          use_safetensors=True,
          dtype=self.config.weights_dtype,
      )

      # loading from pretrained here causes a crash when trying to compile the model
      # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
      transformer = FluxTransformer2DModel.from_config(
          self.config.pretrained_model_name_or_path,
          subfolder="transformer",
          mesh=self.mesh,
          split_head_dim=self.config.split_head_dim,
          attention_kernel=self.config.attention,
          flash_block_sizes=flash_block_sizes,
          dtype=self.config.activations_dtype,
          weights_dtype=self.config.weights_dtype,
          precision=max_utils.get_precision(self.config),
      )
      transformer_eval_params = transformer.init_weights(
          rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
      )

      transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

      pipeline = FluxPipeline(
          t5_encoder,
          clip_encoder,
          vae,
          t5_tokenizer,
          clip_tokenizer,
          transformer,
          None,
          dtype=self.config.activations_dtype,
          mesh=self.mesh,
          config=self.config,
          rng=self.rng,
      )

    params = {FLUX_VAE_PARAMS_KEY: vae_params, FLUX_TRANSFORMER_PARAMS_KEY: transformer_params}

    return pipeline, params

  def load_checkpoint(self, step=None, scheduler_class=None):

    model_configs = self.load_flux_configs_from_orbax(step)

    pipeline, params = None, {}

    if model_configs:
      if jax.device_count() == jax.local_device_count():
        context = jax.default_device(jax.devices("cpu")[0])
      else:
        context = nullcontext()

      with context:
        clip_encoder = FlaxCLIPTextModel.from_pretrained(
            self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
        )
        clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True)
        t5_encoder = FlaxT5EncoderModel.from_pretrained(
            self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype
        )
        t5_tokenizer = AutoTokenizer.from_pretrained(
            self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
        )

        vae = FlaxAutoencoderKL.from_config(
            model_configs[0]["vae_config"],
            dtype=self.config.activations_dtype,
            weights_dtype=self.config.weights_dtype,
            from_pt=self.config.from_pt,
        )

        transformer = FluxTransformer2DModel.from_config(
            model_configs[0]["flux_config"],
            mesh=self.mesh,
            split_head_dim=self.config.split_head_dim,
            attention_kernel=self.config.attention,
            flash_block_sizes=max_utils.get_flash_block_sizes(self.config),
            dtype=self.config.activations_dtype,
            weights_dtype=self.config.weights_dtype,
            precision=max_utils.get_precision(self.config),
            from_pt=self.config.from_pt,
        )

        pipeline = FluxPipeline(
            t5_encoder,
            clip_encoder,
            vae,
            t5_tokenizer,
            clip_tokenizer,
            transformer,
            None,
            dtype=self.config.activations_dtype,
            mesh=self.mesh,
            config=self.config,
            rng=self.rng,
        )

    else:
      pipeline, params = self.load_diffusers_checkpoint()

    return pipeline, params
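A rough end-to-end sketch of how a trainer might drive this checkpointer, assuming the module path above and a loaded maxdiffusion `config` with the fields this file references; `config`, `train_step`, and `scheduler_state` are assumed to come from the surrounding trainer, and the scheduler entry in train_states is an assumption rather than something this file defines:

# Illustrative only; the trainer added in this PR wires these calls up itself.
from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer, FLUX_CHECKPOINT

checkpointer = FluxCheckpointer(config, FLUX_CHECKPOINT)

# Builds the pipeline from restored Orbax configs when available; otherwise loads
# the diffusers/Flux weights on CPU and constructs a fresh FluxPipeline.
pipeline, params = checkpointer.load_checkpoint()

# Sharded train state for the transformer (optimizer attached) and the frozen VAE.
flux_state, flux_state_shardings, lr_scheduler = checkpointer.create_flux_state(
    pipeline, params, checkpoint_item_name="flux_state", is_training=True
)
vae_state, vae_state_shardings = checkpointer.create_vae_state(
    pipeline, params, checkpoint_item_name="vae_state", is_training=False
)

# After some training steps, persist the states alongside the JSON model/scheduler configs.
train_states = {
    "flux_state": flux_state,
    "vae_state": vae_state,
    "scheduler": scheduler_state,  # assumed to come from the trainer's scheduler setup
}
checkpointer.save_checkpoint(train_step, pipeline, train_states)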
