Skip to content

Commit 044d75a

Browse files
committed
more warnings.
1 parent 585fed4 commit 044d75a

File tree

7 files changed

+170
-3
lines changed

7 files changed

+170
-3
lines changed

src/diffusers/models/attention_flax.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
import jax
2020
import jax.numpy as jnp
2121

22+
from ..utils import logging
23+
24+
25+
logger = logging.get_logger(__name__)
26+
2227

2328
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
2429
"""Multi-head dot product attention with a limited number of queries."""
@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
151156
dtype: jnp.dtype = jnp.float32
152157

153158
def setup(self):
159+
logger.warning(
160+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
161+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
162+
)
163+
154164
inner_dim = self.dim_head * self.heads
155165
self.scale = self.dim_head**-0.5
156166

@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
277287
split_head_dim: bool = False
278288

279289
def setup(self):
290+
logger.warning(
291+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
292+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
293+
)
294+
280295
# self attention (or cross_attention if only_cross_attention is True)
281296
self.attn1 = FlaxAttention(
282297
self.dim,
@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
365380
split_head_dim: bool = False
366381

367382
def setup(self):
383+
logger.warning(
384+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
385+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
386+
)
387+
368388
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
369389

370390
inner_dim = self.n_heads * self.d_head
@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
454474
dtype: jnp.dtype = jnp.float32
455475

456476
def setup(self):
477+
logger.warning(
478+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
479+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
480+
)
481+
457482
# The second linear layer needs to be called
458483
# net_2 for now to match the index of the Sequential layer
459484
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
484509
dtype: jnp.dtype = jnp.float32
485510

486511
def setup(self):
512+
logger.warning(
513+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
514+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
515+
)
516+
487517
inner_dim = self.dim * 4
488518
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
489519
self.dropout_layer = nn.Dropout(rate=self.dropout)

src/diffusers/models/controlnets/controlnet_flax.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from flax.core.frozen_dict import FrozenDict
2121

2222
from ...configuration_utils import ConfigMixin, flax_register_to_config
23-
from ...utils import BaseOutput
23+
from ...utils import BaseOutput, logging
2424
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
2525
from ..modeling_flax_utils import FlaxModelMixin
2626
from ..unets.unet_2d_blocks_flax import (
@@ -30,6 +30,9 @@
3030
)
3131

3232

33+
logger = logging.get_logger(__name__)
34+
35+
3336
@flax.struct.dataclass
3437
class FlaxControlNetOutput(BaseOutput):
3538
"""
@@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
5053
dtype: jnp.dtype = jnp.float32
5154

5255
def setup(self) -> None:
56+
logger.warning(
57+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
58+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
59+
)
60+
5361
self.conv_in = nn.Conv(
5462
self.block_out_channels[0],
5563
kernel_size=(3, 3),
@@ -184,6 +192,11 @@ def init_weights(self, rng: jax.Array) -> FrozenDict:
184192
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
185193

186194
def setup(self) -> None:
195+
logger.warning(
196+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
197+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
198+
)
199+
187200
block_out_channels = self.block_out_channels
188201
time_embed_dim = block_out_channels[0] * 4
189202

src/diffusers/models/embeddings_flax.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
import flax.linen as nn
1717
import jax.numpy as jnp
1818

19+
from ..utils import logging
20+
21+
22+
logger = logging.get_logger(__name__)
23+
1924

2025
def get_sinusoidal_embeddings(
2126
timesteps: jnp.ndarray,
@@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module):
7681
The data type for the embedding parameters.
7782
"""
7883

84+
logger.warning(
85+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
86+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
87+
)
88+
7989
time_embed_dim: int = 32
8090
dtype: jnp.dtype = jnp.float32
8191

@@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module):
104114
flip_sin_to_cos: bool = False
105115
freq_shift: float = 1
106116

117+
logger.warning(
118+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
119+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
120+
)
121+
107122
@nn.compact
108123
def __call__(self, timesteps):
109124
return get_sinusoidal_embeddings(

src/diffusers/models/resnet_flax.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,22 @@
1515
import jax
1616
import jax.numpy as jnp
1717

18+
from ..utils import logging
19+
20+
21+
logger = logging.get_logger(__name__)
22+
1823

1924
class FlaxUpsample2D(nn.Module):
2025
out_channels: int
2126
dtype: jnp.dtype = jnp.float32
2227

2328
def setup(self):
29+
logger.warning(
30+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
31+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
32+
)
33+
2434
self.conv = nn.Conv(
2535
self.out_channels,
2636
kernel_size=(3, 3),
@@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module):
4555
dtype: jnp.dtype = jnp.float32
4656

4757
def setup(self):
58+
logger.warning(
59+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
60+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
61+
)
62+
4863
self.conv = nn.Conv(
4964
self.out_channels,
5065
kernel_size=(3, 3),
@@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module):
6883
dtype: jnp.dtype = jnp.float32
6984

7085
def setup(self):
86+
logger.warning(
87+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
88+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
89+
)
90+
7191
out_channels = self.in_channels if self.out_channels is None else self.out_channels
7292

7393
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)

src/diffusers/models/unets/unet_2d_blocks_flax.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
import flax.linen as nn
1616
import jax.numpy as jnp
1717

18+
from ...utils import logging
1819
from ..attention_flax import FlaxTransformer2DModel
1920
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
2021

2122

23+
logger = logging.get_logger(__name__)
24+
25+
2226
class FlaxCrossAttnDownBlock2D(nn.Module):
2327
r"""
2428
Cross Attention 2D Downsizing block - original architecture from Unet transformers:
@@ -60,6 +64,11 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
6064
transformer_layers_per_block: int = 1
6165

6266
def setup(self):
67+
logger.warning(
68+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
69+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
70+
)
71+
6372
resnets = []
6473
attentions = []
6574

@@ -135,6 +144,11 @@ class FlaxDownBlock2D(nn.Module):
135144
dtype: jnp.dtype = jnp.float32
136145

137146
def setup(self):
147+
logger.warning(
148+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
149+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
150+
)
151+
138152
resnets = []
139153

140154
for i in range(self.num_layers):
@@ -208,6 +222,11 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
208222
transformer_layers_per_block: int = 1
209223

210224
def setup(self):
225+
logger.warning(
226+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
227+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
228+
)
229+
211230
resnets = []
212231
attentions = []
213232

@@ -288,6 +307,11 @@ class FlaxUpBlock2D(nn.Module):
288307
dtype: jnp.dtype = jnp.float32
289308

290309
def setup(self):
310+
logger.warning(
311+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
312+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
313+
)
314+
291315
resnets = []
292316

293317
for i in range(self.num_layers):
@@ -356,6 +380,11 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
356380
transformer_layers_per_block: int = 1
357381

358382
def setup(self):
383+
logger.warning(
384+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
385+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
386+
)
387+
359388
# there is always at least one resnet
360389
resnets = [
361390
FlaxResnetBlock2D(

src/diffusers/models/unets/unet_2d_condition_flax.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from flax.core.frozen_dict import FrozenDict
2121

2222
from ...configuration_utils import ConfigMixin, flax_register_to_config
23-
from ...utils import BaseOutput
23+
from ...utils import BaseOutput, logging
2424
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
2525
from ..modeling_flax_utils import FlaxModelMixin
2626
from .unet_2d_blocks_flax import (
@@ -32,6 +32,9 @@
3232
)
3333

3434

35+
logger = logging.get_logger(__name__)
36+
37+
3538
@flax.struct.dataclass
3639
class FlaxUNet2DConditionOutput(BaseOutput):
3740
"""
@@ -163,6 +166,11 @@ def init_weights(self, rng: jax.Array) -> FrozenDict:
163166
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
164167

165168
def setup(self) -> None:
169+
logger.warning(
170+
"Flax classes are deprecated and will be removed in Diffusers v1. We "
171+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
172+
)
173+
166174
block_out_channels = self.block_out_channels
167175
time_embed_dim = block_out_channels[0] * 4
168176

0 commit comments

Comments
 (0)