@@ -19,6 +19,11 @@
 import jax
 import jax.numpy as jnp
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
     """Multi-head dot product attention with a limited number of queries."""
@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         inner_dim = self.dim_head * self.heads
         self.scale = self.dim_head**-0.5
 
@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
     split_head_dim: bool = False
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         # self attention (or cross_attention if only_cross_attention is True)
         self.attn1 = FlaxAttention(
             self.dim,
@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
     split_head_dim: bool = False
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
 
         inner_dim = self.n_heads * self.d_head
@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         # The second linear layer needs to be called
        # net_2 for now to match the index of the Sequential layer
         self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.dropout)
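Each `setup()` above now emits this deprecation warning through the `diffusers` logging utilities imported at the top of the file. For downstream code that cannot migrate to the PyTorch classes yet, the warning can be quieted at runtime; this is a minimal sketch, assuming `diffusers.utils.logging` exposes the usual transformers-style verbosity helpers such as `set_verbosity_error()`:

    from diffusers.utils import logging

    # Assumption: set_verbosity_error() raises the library log level so that
    # only ERROR-level messages are emitted, which suppresses the setup()
    # deprecation warnings added in this change.
    logging.set_verbosity_error()

Pinning the Diffusers version (e.g. a `diffusers<1.0` requirement) remains the longer-term option the warning text itself recommends, since the Flax classes are slated for removal in v1.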