Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/en/api/models/autoencoderkl.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ model = AutoencoderKL.from_single_file(url)

## FlaxAutoencoderKLOutput

[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
[[autodoc]] models.autoencoders.vae_flax.FlaxAutoencoderKLOutput

## FlaxDecoderOutput

[[autodoc]] models.vae_flax.FlaxDecoderOutput
[[autodoc]] models.autoencoders.vae_flax.FlaxDecoderOutput
4 changes: 2 additions & 2 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,10 @@


else:
_import_structure["models.autoencoders.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
[
Expand Down Expand Up @@ -791,10 +791,10 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_objects import * # noqa F403
else:
from .models.autoencoders.vae_flax import FlaxAutoencoderKL
from .models.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
FlaxDDIMScheduler,
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]

if is_flax_available():
_import_structure["autoencoders.vae_flax"] = ["FlaxAutoencoderKL"]

_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
Expand Down Expand Up @@ -107,9 +107,9 @@
)

if is_flax_available():
from .autoencoders import FlaxAutoencoderKL
from .controlnet_flax import FlaxControlNetModel
from .unets import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL

else:
import sys
Expand Down
20 changes: 14 additions & 6 deletions src/diffusers/models/autoencoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .vq_model import VQModel
from ...utils import is_flax_available, is_torch_available


if is_torch_available():
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .vq_model import VQModel


if is_flax_available():
from .vae_flax import FlaxAutoencoderKL
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .modeling_flax_utils import FlaxModelMixin
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ..modeling_flax_utils import FlaxModelMixin


@flax.struct.dataclass
Expand Down
8 changes: 4 additions & 4 deletions src/diffusers/utils/dummy_flax_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ..utils import DummyObject, requires_backends


class FlaxControlNetModel(metaclass=DummyObject):
class FlaxAutoencoderKL(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
Expand All @@ -17,7 +17,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxModelMixin(metaclass=DummyObject):
class FlaxControlNetModel(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
Expand All @@ -32,7 +32,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxUNet2DConditionModel(metaclass=DummyObject):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be removed?

class FlaxModelMixin(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
Expand All @@ -47,7 +47,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxAutoencoderKL(metaclass=DummyObject):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be removed?

class FlaxUNet2DConditionModel(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
Expand Down