diff --git a/setup.py b/setup.py index ba3ad8e2b307..cb1649f35451 100644 --- a/setup.py +++ b/setup.py @@ -100,7 +100,6 @@ "compel==0.1.8", "datasets", "filelock", - "flax>=0.4.1", "hf-doc-builder>=0.3.0", "huggingface-hub>=0.34.0", "requests-mock==1.10.0", @@ -137,6 +136,7 @@ "requests", "tensorboard", "tiktoken>=0.7.0", + "flax>=0.4.1", "torch>=1.4", "torchvision", "transformers>=4.41.2", @@ -252,6 +252,7 @@ def run(self): else: extras["flax"] = deps_list("jax", "jaxlib", "flax") + extras["dev"] = ( extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] ) @@ -265,6 +266,7 @@ def run(self): deps["requests"], deps["safetensors"], deps["Pillow"], + deps["torch"], ] version_range_max = max(sys.version_info[1], 10) + 1 diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 79dc4c50a050..367ff89b9a42 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -7,7 +7,6 @@ "compel": "compel==0.1.8", "datasets": "datasets", "filelock": "filelock", - "flax": "flax>=0.4.1", "hf-doc-builder": "hf-doc-builder>=0.3.0", "huggingface-hub": "huggingface-hub>=0.34.0", "requests-mock": "requests-mock==1.10.0", @@ -44,6 +43,7 @@ "requests": "requests", "tensorboard": "tensorboard", "tiktoken": "tiktoken>=0.7.0", + "flax": "flax>=0.4.1", "torch": "torch>=1.4", "torchvision": "torchvision", "transformers": "transformers>=4.41.2",