Skip to content

Commit 0ba7bd4

Browse files
committed
add flax back.
1 parent 522e70d commit 0ba7bd4

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

setup.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
you need to go back to main before executing this.
8585
"""
8686

87+
import os
8788
import re
8889
import sys
8990

@@ -134,6 +135,7 @@
134135
"requests",
135136
"tensorboard",
136137
"tiktoken>=0.7.0",
138+
"flax>=0.4.1",
137139
"torch>=1.4",
138140
"torchvision",
139141
"transformers>=4.41.2",
@@ -243,6 +245,11 @@ def run(self):
243245
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
244246
extras["torchao"] = deps_list("torchao", "accelerate")
245247

248+
if os.name == "nt": # windows
249+
extras["flax"] = [] # jax is not supported on windows
250+
else:
251+
extras["flax"] = deps_list("jax", "jaxlib", "flax")
252+
246253

247254
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"]
248255

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"requests": "requests",
4343
"tensorboard": "tensorboard",
4444
"tiktoken": "tiktoken>=0.7.0",
45+
"flax": "flax>=0.4.1",
4546
"torch": "torch>=1.4",
4647
"torchvision": "torchvision",
4748
"transformers": "transformers>=4.41.2",

0 commit comments

Comments
 (0)