Skip to content

Commit 1a917d3

Browse files
committed
Revert "merge main"
This reverts commit 65efbce.
1 parent 65efbce commit 1a917d3

File tree

215 files changed

+3675
-1671
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

215 files changed

+3675
-1671
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: Run Flax dependency tests
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
paths:
8+
- "src/diffusers/**.py"
9+
push:
10+
branches:
11+
- main
12+
13+
concurrency:
14+
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
15+
cancel-in-progress: true
16+
17+
jobs:
18+
check_flax_dependencies:
19+
runs-on: ubuntu-22.04
20+
steps:
21+
- uses: actions/checkout@v3
22+
- name: Set up Python
23+
uses: actions/setup-python@v4
24+
with:
25+
python-version: "3.8"
26+
- name: Install dependencies
27+
run: |
28+
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
29+
python -m pip install --upgrade pip uv
30+
python -m uv pip install -e .
31+
python -m uv pip install "jax[cpu]>=0.2.16,!=0.3.2"
32+
python -m uv pip install "flax>=0.4.1"
33+
python -m uv pip install "jaxlib>=0.1.65"
34+
python -m uv pip install pytest
35+
- name: Check for soft dependencies
36+
run: |
37+
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
38+
pytest tests/others/test_dependencies.py

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ limitations under the License.
3737

3838
## Installation
3939

40-
We recommend installing 🤗 Diffusers in a virtual environment from PyPI or Conda. For more details about installing [PyTorch](https://pytorch.org/get-started/locally/), please refer to their official documentation.
40+
We recommend installing 🤗 Diffusers in a virtual environment from PyPI or Conda. For more details about installing [PyTorch](https://pytorch.org/get-started/locally/) and [Flax](https://flax.readthedocs.io/en/latest/#installation), please refer to their official documentation.
4141

4242
### PyTorch
4343

@@ -53,6 +53,14 @@ With `conda` (maintained by the community):
5353
conda install -c conda-forge diffusers
5454
```
5555

56+
### Flax
57+
58+
With `pip` (official package):
59+
60+
```bash
61+
pip install --upgrade diffusers[flax]
62+
```
63+
5664
### Apple Silicon (M1/M2) support
5765

5866
Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggingface.co/docs/diffusers/optimization/mps) guide.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
FROM ubuntu:20.04
2+
LABEL maintainer="Hugging Face"
3+
LABEL repository="diffusers"
4+
5+
ENV DEBIAN_FRONTEND=noninteractive
6+
7+
RUN apt-get -y update \
8+
&& apt-get install -y software-properties-common \
9+
&& add-apt-repository ppa:deadsnakes/ppa
10+
11+
RUN apt install -y bash \
12+
build-essential \
13+
git \
14+
git-lfs \
15+
curl \
16+
ca-certificates \
17+
libsndfile1-dev \
18+
libgl1 \
19+
python3.10 \
20+
python3-pip \
21+
python3.10-venv && \
22+
rm -rf /var/lib/apt/lists
23+
24+
# make sure to use venv
25+
RUN python3.10 -m venv /opt/venv
26+
ENV PATH="/opt/venv/bin:$PATH"
27+
28+
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
29+
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
30+
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
31+
python3 -m uv pip install --upgrade --no-cache-dir \
32+
clu \
33+
"jax[cpu]>=0.2.16,!=0.3.2" \
34+
"flax>=0.4.1" \
35+
"jaxlib>=0.1.65" && \
36+
python3 -m uv pip install --no-cache-dir \
37+
accelerate \
38+
datasets \
39+
hf-doc-builder \
40+
huggingface-hub \
41+
Jinja2 \
42+
librosa \
43+
numpy==1.26.4 \
44+
scipy \
45+
tensorboard \
46+
transformers \
47+
hf_transfer
48+
49+
CMD ["/bin/bash"]
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
FROM ubuntu:20.04
2+
LABEL maintainer="Hugging Face"
3+
LABEL repository="diffusers"
4+
5+
ENV DEBIAN_FRONTEND=noninteractive
6+
7+
RUN apt-get -y update \
8+
&& apt-get install -y software-properties-common \
9+
&& add-apt-repository ppa:deadsnakes/ppa
10+
11+
RUN apt install -y bash \
12+
build-essential \
13+
git \
14+
git-lfs \
15+
curl \
16+
ca-certificates \
17+
libsndfile1-dev \
18+
libgl1 \
19+
python3.10 \
20+
python3-pip \
21+
python3.10-venv && \
22+
rm -rf /var/lib/apt/lists
23+
24+
# make sure to use venv
25+
RUN python3.10 -m venv /opt/venv
26+
ENV PATH="/opt/venv/bin:$PATH"
27+
28+
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
29+
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
30+
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
31+
python3 -m pip install --no-cache-dir \
32+
"jax[tpu]>=0.2.16,!=0.3.2" \
33+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
34+
python3 -m uv pip install --upgrade --no-cache-dir \
35+
clu \
36+
"flax>=0.4.1" \
37+
"jaxlib>=0.1.65" && \
38+
python3 -m uv pip install --no-cache-dir \
39+
accelerate \
40+
datasets \
41+
hf-doc-builder \
42+
huggingface-hub \
43+
Jinja2 \
44+
librosa \
45+
numpy==1.26.4 \
46+
scipy \
47+
tensorboard \
48+
transformers \
49+
hf_transfer
50+
51+
CMD ["/bin/bash"]

docs/source/en/_toctree.yml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,19 @@
99
- local: stable_diffusion
1010
title: Basic performance
1111

12-
- title: Pipelines
12+
- title: DiffusionPipeline
1313
isExpanded: false
1414
sections:
1515
- local: using-diffusers/loading
16-
title: DiffusionPipeline
16+
title: Load pipelines
1717
- local: tutorials/autopipeline
1818
title: AutoPipeline
1919
- local: using-diffusers/custom_pipeline_overview
2020
title: Community pipelines and components
2121
- local: using-diffusers/callback
2222
title: Pipeline callbacks
2323
- local: using-diffusers/reusing_seeds
24-
title: Reproducibility
24+
title: Reproducible pipelines
2525
- local: using-diffusers/schedulers
2626
title: Load schedulers and models
2727
- local: using-diffusers/scheduler_features
@@ -62,6 +62,8 @@
6262
title: Scheduler features
6363
- local: using-diffusers/callback
6464
title: Pipeline callbacks
65+
- local: using-diffusers/reusing_seeds
66+
title: Reproducible pipelines
6567
- local: using-diffusers/image_quality
6668
title: Controlling image quality
6769

@@ -75,7 +77,7 @@
7577
- local: optimization/memory
7678
title: Reduce memory usage
7779
- local: optimization/speed-memory-optims
78-
title: Compiling and offloading quantized models
80+
title: Compile and offloading quantized models
7981
- title: Community optimizations
8082
sections:
8183
- local: optimization/pruna
@@ -192,6 +194,8 @@
192194
- title: Model accelerators and hardware
193195
isExpanded: false
194196
sections:
197+
- local: using-diffusers/stable_diffusion_jax_how_to
198+
title: JAX/Flax
195199
- local: optimization/onnx
196200
title: ONNX
197201
- local: optimization/open_vino

docs/source/en/api/models/autoencoderkl.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,15 @@ model = AutoencoderKL.from_single_file(url)
4444
## DecoderOutput
4545

4646
[[autodoc]] models.autoencoders.vae.DecoderOutput
47+
48+
## FlaxAutoencoderKL
49+
50+
[[autodoc]] FlaxAutoencoderKL
51+
52+
## FlaxAutoencoderKLOutput
53+
54+
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
55+
56+
## FlaxDecoderOutput
57+
58+
[[autodoc]] models.vae_flax.FlaxDecoderOutput

docs/source/en/api/models/controlnet.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,11 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
4040
## ControlNetOutput
4141

4242
[[autodoc]] models.controlnets.controlnet.ControlNetOutput
43+
44+
## FlaxControlNetModel
45+
46+
[[autodoc]] FlaxControlNetModel
47+
48+
## FlaxControlNetOutput
49+
50+
[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput

docs/source/en/api/models/overview.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ All models are built from the base [`ModelMixin`] class which is a [`torch.nn.Mo
1919
## ModelMixin
2020
[[autodoc]] ModelMixin
2121

22+
## FlaxModelMixin
23+
24+
[[autodoc]] FlaxModelMixin
25+
2226
## PushToHubMixin
2327

2428
[[autodoc]] utils.PushToHubMixin

docs/source/en/api/models/unet2d-cond.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,9 @@ The abstract from the paper is:
2323

2424
## UNet2DConditionOutput
2525
[[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput
26+
27+
## FlaxUNet2DConditionModel
28+
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionModel
29+
30+
## FlaxUNet2DConditionOutput
31+
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput

docs/source/en/api/outputs.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ To check a specific pipeline or model output, refer to its corresponding API doc
5454

5555
[[autodoc]] pipelines.ImagePipelineOutput
5656

57+
## FlaxImagePipelineOutput
58+
59+
[[autodoc]] pipelines.pipeline_flax_utils.FlaxImagePipelineOutput
60+
5761
## AudioPipelineOutput
5862

5963
[[autodoc]] pipelines.AudioPipelineOutput

0 commit comments

Comments
 (0)