Skip to content

Commit 230c460

Browse files
authored
Fixes for FLUX (#258)
* fix for gpu * address comment * add TE context to sdxl
1 parent cc270b6 commit 230c460

File tree

4 files changed

+29
-4
lines changed

4 files changed

+29
-4
lines changed

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ enable_profiler: False
236236
# the iteration time a chance to stabilize.
237237
skip_first_n_steps_for_profiler: 5
238238
profiler_steps: 10
239+
profiler: ""
239240

240241
# Generation parameters
241242
prompt: "A magical castle in the middle of a forest, artistic drawing"
@@ -284,3 +285,5 @@ quantization: ''
284285
quantization_local_shard_count: -1
285286
use_qwix_quantization: False
286287
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
288+
289+
save_final_checkpoint: False

src/maxdiffusion/train_flux.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from maxdiffusion.train_utils import (
2424
validate_train_config,
25+
transformer_engine_context,
2526
)
2627

2728

@@ -39,6 +40,6 @@ def main(argv: Sequence[str]) -> None:
3940
max_logging.log(f"Found {jax.device_count()} devices.")
4041
train(config)
4142

42-
4343
if __name__ == "__main__":
44-
app.run(main)
44+
with transformer_engine_context():
45+
app.run(main)

src/maxdiffusion/train_sdxl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from maxdiffusion.train_utils import (
2929
validate_train_config,
30+
transformer_engine_context,
3031
)
3132

3233

@@ -51,4 +52,5 @@ def main(argv: Sequence[str]) -> None:
5152
tf.config.set_visible_devices([], "GPU")
5253
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
5354
torch.set_default_device("cpu")
54-
app.run(main)
55+
with transformer_engine_context():
56+
app.run(main)

src/maxdiffusion/train_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import queue
2121

2222
from maxdiffusion import max_utils, max_logging
23-
23+
from contextlib import contextmanager
2424

2525
def get_first_step(state):
2626
return int(state.step)
@@ -196,3 +196,22 @@ def generate_timestep_weights(config, num_timesteps):
196196
weights[bias_indices] *= timestep_bias_config["multiplier"]
197197
weights /= weights.sum()
198198
return jnp.array(weights)
199+
200+
201+
@contextmanager
202+
def transformer_engine_context():
203+
""" If TransformerEngine is available, this context manager will provide the library with MaxDiffusion-specific details needed for correcct operation. """
204+
try:
205+
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
206+
# Inform TransformerEngine of MaxDiffusion's physical mesh resources.
207+
mesh_resource = MeshResource(
208+
dp_resource = "data",
209+
tp_resource = "tensor",
210+
fsdp_resource = "fsdp",
211+
pp_resource = None,
212+
cp_resource = None,
213+
)
214+
with global_shard_guard(mesh_resource):
215+
yield
216+
except ImportError:
217+
yield

0 commit comments

Comments
 (0)