Skip to content

Commit f68692a

Browse files
authored
Skip jax distributed initialize if specified to skip (#172)
1 parent 6d54a84 commit f68692a

File tree

11 files changed

+13
-3
lines changed

11 files changed

+13
-3
lines changed

.github/workflows/UploadDockerImages.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
run: docker system prune --all --force
4545
- name: build maxdiffusion jax stable stack gpu image
4646
run: |
47-
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest DEVICE=gpu
47+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_gpu MODE=stable PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_gpu DEVICE=gpu
4848
- name: build maxdiffusion jax nightly image
4949
run: |
5050
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu

src/maxdiffusion/configs/base14.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ diffusion_scheduler_config: {
9292

9393
# Hardware
9494
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
95+
skip_jax_distributed_system: False
9596

9697
base_output_directory: ""
9798

src/maxdiffusion/configs/base21.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ diffusion_scheduler_config: {
9191

9292
# Hardware
9393
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
94+
skip_jax_distributed_system: False
9495

9596
# Output directory
9697
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ diffusion_scheduler_config: {
104104

105105
# Hardware
106106
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
107+
skip_jax_distributed_system: False
107108

108109
# Output directory
109110
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ base_output_directory: ""
120120

121121
# Hardware
122122
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
123+
skip_jax_distributed_system: False
123124

124125
# Parallelism
125126
mesh_axes: ['data', 'fsdp', 'tensor']

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ base_output_directory: ""
120120

121121
# Hardware
122122
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
123+
skip_jax_distributed_system: False
123124

124125
# Parallelism
125126
mesh_axes: ['data', 'fsdp', 'tensor']

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ base_output_directory: ""
128128

129129
# Hardware
130130
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
131+
skip_jax_distributed_system: False
131132

132133
# Parallelism
133134
mesh_axes: ['data', 'fsdp', 'tensor']

src/maxdiffusion/configs/base_wan_t2v.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ base_output_directory: ""
114114

115115
# Hardware
116116
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
117+
skip_jax_distributed_system: False
117118

118119
# Parallelism
119120
mesh_axes: ['data', 'fsdp', 'tensor']

src/maxdiffusion/configs/base_xl.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ base_output_directory: ""
9595

9696
# Hardware
9797
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
98-
98+
skip_jax_distributed_system: False
9999
# Parallelism
100100
mesh_axes: ['data', 'fsdp', 'tensor']
101101

src/maxdiffusion/configs/base_xl_lightning.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ diffusion_scheduler_config: {
7171

7272
# Hardware
7373
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
74-
74+
skip_jax_distributed_system: False
7575
# Output directory
7676
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
7777
base_output_directory: ""

0 commit comments

Comments
 (0)