Skip to content

Commit 158e1f2

Browse files
authored
Fix WAN training/inference error for new JAX/Flax Version (#265)
1 parent 972b4ff commit 158e1f2

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

src/maxdiffusion/generate_wan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from absl import app
2222
from maxdiffusion.utils import export_to_video
2323
from google.cloud import storage
24+
import flax
2425

2526

2627
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -161,6 +162,7 @@ def run(config, pipeline=None, filename_prefix=""):
161162

162163
def main(argv: Sequence[str]) -> None:
163164
pyconfig.initialize(argv)
165+
flax.config.update('flax_always_shard_variable', False)
164166
run(pyconfig.config)
165167

166168

src/maxdiffusion/train_wan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from absl import app
2121
from maxdiffusion import max_logging, pyconfig
2222
from maxdiffusion.train_utils import validate_train_config
23+
import flax
2324

2425

2526
def train(config):
@@ -34,6 +35,7 @@ def main(argv: Sequence[str]) -> None:
3435
config = pyconfig.config
3536
validate_train_config(config)
3637
max_logging.log(f"Found {jax.device_count()} devices.")
38+
flax.config.update('flax_always_shard_variable', False)
3739
train(config)
3840

3941

0 commit comments

Comments
 (0)