
Commit 532f41c

sayakpaul and DN6 authored
Deprecate Flax support (huggingface#12151)
* start removing flax stuff.
* add deprecation warning.
* add warning messages.
* more warnings.
* remove dockerfiles.
* remove more.
* Update src/diffusers/models/attention_flax.py
  Co-authored-by: Dhruv Nair <[email protected]>
* up

---------

Co-authored-by: Dhruv Nair <[email protected]>
1 parent 5fcd5f5 commit 532f41c
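
In practice, the Flax classes touched in this diff now log at runtime: "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers." A minimal sketch of the two options the warning suggests; the checkpoint id and dtype below are illustrative, not taken from this commit:

import torch

from diffusers import UNet2DConditionModel

# Option 1: migrate to the PyTorch classes, which remain supported. The repo id
# is a placeholder for any checkpoint that ships PyTorch weights.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# Option 2: keep the Flax classes for now by pinning Diffusers below the removal
# release, e.g. `pip install "diffusers<1.0.0"`.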

File tree

21 files changed: +186 -1848 lines

.github/workflows/pr_flax_dependency_test.yml

Lines changed: 0 additions & 38 deletions
This file was deleted.

docker/diffusers-flax-cpu/Dockerfile

Lines changed: 0 additions & 49 deletions
This file was deleted.

docker/diffusers-flax-tpu/Dockerfile

Lines changed: 0 additions & 51 deletions
This file was deleted.

src/diffusers/models/attention_flax.py

Lines changed: 30 additions & 0 deletions
@@ -19,6 +19,11 @@
 import jax
 import jax.numpy as jnp
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
     """Multi-head dot product attention with a limited number of queries."""
@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         inner_dim = self.dim_head * self.heads
         self.scale = self.dim_head**-0.5
 
@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
     split_head_dim: bool = False
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         # self attention (or cross_attention if only_cross_attention is True)
         self.attn1 = FlaxAttention(
             self.dim,
@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
     split_head_dim: bool = False
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
 
         inner_dim = self.n_heads * self.d_head
@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         # The second linear layer needs to be called
         # net_2 for now to match the index of the Sequential layer
         self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.dropout)
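
Because each warning is emitted from setup(), it fires when a module is bound (initialized or applied), not when the Python object is constructed. A minimal sketch using FlaxAttention from this file; the hyperparameters and input shape are illustrative:

import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxAttention

# Constructing the module dataclass does not run setup(), so nothing is logged yet.
attn = FlaxAttention(query_dim=32, heads=2, dim_head=16, dtype=jnp.float32)

# init() binds the module and runs setup(), which now logs the deprecation warning.
hidden_states = jnp.ones((1, 8, 32), dtype=jnp.float32)
params = attn.init(jax.random.PRNGKey(0), hidden_states)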

src/diffusers/models/controlnets/controlnet_flax.py

Lines changed: 14 additions & 1 deletion
@@ -20,7 +20,7 @@
 from flax.core.frozen_dict import FrozenDict
 
 from ...configuration_utils import ConfigMixin, flax_register_to_config
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
 from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
 from ..modeling_flax_utils import FlaxModelMixin
 from ..unets.unet_2d_blocks_flax import (
@@ -30,6 +30,9 @@
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class FlaxControlNetOutput(BaseOutput):
     """
@@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self) -> None:
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv_in = nn.Conv(
             self.block_out_channels[0],
             kernel_size=(3, 3),
@@ -184,6 +192,11 @@ def init_weights(self, rng: jax.Array) -> FrozenDict:
         return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
 
     def setup(self) -> None:
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4
 

src/diffusers/models/embeddings_flax.py

Lines changed: 15 additions & 0 deletions
@@ -16,6 +16,11 @@
 import flax.linen as nn
 import jax.numpy as jnp
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 def get_sinusoidal_embeddings(
     timesteps: jnp.ndarray,
@@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module):
            The data type for the embedding parameters.
    """

+    logger.warning(
+        "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+        "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+    )
+
     time_embed_dim: int = 32
     dtype: jnp.dtype = jnp.float32
 
@@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module):
     flip_sin_to_cos: bool = False
     freq_shift: float = 1
 
+    logger.warning(
+        "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+        "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+    )
+
     @nn.compact
     def __call__(self, timesteps):
         return get_sinusoidal_embeddings(

src/diffusers/models/modeling_flax_utils.py

Lines changed: 4 additions & 0 deletions
@@ -290,6 +290,10 @@ def from_pretrained(
         You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
         ```
         """
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         config = kwargs.pop("config", None)
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
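
Because the warning sits at the top of FlaxModelMixin.from_pretrained, every Flax checkpoint load now logs it before any weights are resolved. A hedged usage sketch; the repo id stands in for any repository that actually ships Flax (.msgpack) weights:

import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

# Logs the deprecation warning, then loads as before, returning the module
# definition together with its parameter pytree.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", dtype=jnp.float32
)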

src/diffusers/models/resnet_flax.py

Lines changed: 20 additions & 0 deletions
@@ -15,12 +15,22 @@
 import jax
 import jax.numpy as jnp
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 class FlaxUpsample2D(nn.Module):
     out_channels: int
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),
@@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),
@@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         out_channels = self.in_channels if self.out_channels is None else self.out_channels
 
         self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
