Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions aion/codecs/catalog.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import OrderedDict
from typing import Type, Optional, Dict
from typing import Dict, Optional, Type

from huggingface_hub import PyTorchModelHubMixin
import torch
from jaxtyping import Float
from torch import Tensor
Expand All @@ -13,10 +12,11 @@
IdentityQuantizer,
ScalarReservoirQuantizer,
)
from aion.codecs.utils import CodecPytorchHubMixin
from aion.modalities import LegacySurveyCatalog


class CatalogCodec(Codec, PyTorchModelHubMixin):
class CatalogCodec(Codec, CodecPytorchHubMixin):
"""Codec for catalog quantities.

A codec that embeds catalog quantities through an identity mapping. A
Expand Down
169 changes: 41 additions & 128 deletions aion/codecs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,132 +75,45 @@ class CodecHFConfig:
repo_id: str


CODEC_CONFIG = {
Image: CodecHFConfig(
codec_class=ImageCodec, repo_id="polymathic-ai/aion-image-codec"
),
Spectrum: CodecHFConfig(
codec_class=SpectrumCodec, repo_id="polymathic-ai/aion-spectrum-codec"
),
LegacySurveyCatalog: CodecHFConfig(
codec_class=CatalogCodec, repo_id="polymathic-ai/aion-catalog-codec"
),
LegacySurveySegmentationMap: CodecHFConfig(
codec_class=ScalarFieldCodec, repo_id="polymathic-ai/aion-scalar-field-codec"
),
# Scalar modalities
# LogScalarCodec
LegacySurveyFluxG: CodecHFConfig(
codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-g-codec"
),
LegacySurveyFluxR: CodecHFConfig(
codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-r-codec"
),
LegacySurveyFluxI: CodecHFConfig(
codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-i-codec"
),
LegacySurveyFluxZ: CodecHFConfig(
codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-z-codec"
),
LegacySurveyFluxW1: CodecHFConfig(
codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-w1-codec"
),
LegacySurveyFluxW2: CodecHFConfig(
codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-w2-codec"
),
LegacySurveyFluxW3: CodecHFConfig(
codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-w3-codec"
),
LegacySurveyFluxW4: CodecHFConfig(
codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-w4-codec"
),
LegacySurveyShapeR: CodecHFConfig(
codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-shape-r-codec"
),
GaiaFluxG: CodecHFConfig(
codec_class=LogScalarCodec,
repo_id="polymathic-ai/aion-scalar-phot-g-mean-flux-codec",
),
GaiaFluxBp: CodecHFConfig(
codec_class=LogScalarCodec,
repo_id="polymathic-ai/aion-scalar-phot-bp-mean-flux-codec",
),
GaiaFluxRp: CodecHFConfig(
codec_class=LogScalarCodec,
repo_id="polymathic-ai/aion-scalar-phot-rp-mean-flux-codec",
),
GaiaParallax: CodecHFConfig(
codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-parallax-codec"
),
# ScalarCodec
LegacySurveyShapeE1: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-shape-e1-codec"
),
LegacySurveyShapeE2: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-shape-e2-codec"
),
LegacySurveyEBV: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-ebv-codec"
),
HSCMagG: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-g-cmodel-mag-codec"
),
HSCMagR: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-r-cmodel-mag-codec"
),
HSCMagI: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-i-cmodel-mag-codec"
),
HSCMagZ: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-z-cmodel-mag-codec"
),
HSCMagY: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-y-cmodel-mag-codec"
),
HSCShape11: CodecHFConfig(
codec_class=ScalarCodec,
repo_id="polymathic-ai/aion-scalar-i-sdssshape-shape11-codec",
),
HSCShape22: CodecHFConfig(
codec_class=ScalarCodec,
repo_id="polymathic-ai/aion-scalar-i-sdssshape-shape22-codec",
),
HSCShape12: CodecHFConfig(
codec_class=ScalarCodec,
repo_id="polymathic-ai/aion-scalar-i-sdssshape-shape12-codec",
),
HSCAG: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-a-g-codec"
),
HSCAR: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-a-r-codec"
),
HSCAI: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-a-i-codec"
),
HSCAZ: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-a-z-codec"
),
HSCAY: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-a-y-codec"
),
Ra: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-ra-codec"
),
Dec: CodecHFConfig(
codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-dec-codec"
),
# MultiScalarCodec
GaiaXpBp: CodecHFConfig(
codec_class=MultiScalarCodec,
repo_id="polymathic-ai/aion-scalar-bp-coefficients-codec",
),
GaiaXpRp: CodecHFConfig(
codec_class=MultiScalarCodec,
repo_id="polymathic-ai/aion-scalar-rp-coefficients-codec",
),
# GridScalarCodec
Z: CodecHFConfig(
codec_class=GridScalarCodec, repo_id="polymathic-ai/aion-scalar-z-codec"
),
# Registry mapping each modality class to the codec class used for it.
# Several modalities may share one codec class; the key is the modality type
# that callers look up, the value is the codec class to instantiate.
MODALITY_CODEC_MAPPING = {
    # Modalities with a dedicated codec class
    Image: ImageCodec,
    Spectrum: SpectrumCodec,
    LegacySurveyCatalog: CatalogCodec,
    LegacySurveySegmentationMap: ScalarFieldCodec,
    # Scalar modalities
    # LogScalarCodec
    LegacySurveyFluxG: LogScalarCodec,
    LegacySurveyFluxR: LogScalarCodec,
    LegacySurveyFluxI: LogScalarCodec,
    LegacySurveyFluxZ: LogScalarCodec,
    LegacySurveyFluxW1: LogScalarCodec,
    LegacySurveyFluxW2: LogScalarCodec,
    LegacySurveyFluxW3: LogScalarCodec,
    LegacySurveyFluxW4: LogScalarCodec,
    LegacySurveyShapeR: LogScalarCodec,
    GaiaFluxG: LogScalarCodec,
    GaiaFluxBp: LogScalarCodec,
    GaiaFluxRp: LogScalarCodec,
    GaiaParallax: LogScalarCodec,
    # ScalarCodec
    LegacySurveyShapeE1: ScalarCodec,
    LegacySurveyShapeE2: ScalarCodec,
    LegacySurveyEBV: ScalarCodec,
    HSCMagG: ScalarCodec,
    HSCMagR: ScalarCodec,
    HSCMagI: ScalarCodec,
    HSCMagZ: ScalarCodec,
    HSCMagY: ScalarCodec,
    HSCShape11: ScalarCodec,
    HSCShape22: ScalarCodec,
    HSCShape12: ScalarCodec,
    HSCAG: ScalarCodec,
    HSCAR: ScalarCodec,
    HSCAI: ScalarCodec,
    HSCAZ: ScalarCodec,
    HSCAY: ScalarCodec,
    Ra: ScalarCodec,
    Dec: ScalarCodec,
    # MultiScalarCodec
    GaiaXpBp: MultiScalarCodec,
    GaiaXpRp: MultiScalarCodec,
    # GridScalarCodec
    Z: GridScalarCodec,
}

HF_REPO_ID = "polymathic-ai/aion-base"
4 changes: 2 additions & 2 deletions aion/codecs/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
from huggingface_hub import PyTorchModelHubMixin
from jaxtyping import Float
from torch import Tensor
from typing import Type, Optional, List
Expand All @@ -16,6 +15,7 @@
Clamp,
)
from aion.codecs.preprocessing.band_to_index import BAND_TO_INDEX
from aion.codecs.utils import CodecPytorchHubMixin


class AutoencoderImageCodec(Codec):
Expand Down Expand Up @@ -165,7 +165,7 @@ def decode(
return super().decode(z, bands=bands)


class ImageCodec(AutoencoderImageCodec, PyTorchModelHubMixin):
class ImageCodec(AutoencoderImageCodec, CodecPytorchHubMixin):
def __init__(
self,
quantizer_levels: List[int],
Expand Down
25 changes: 13 additions & 12 deletions aion/codecs/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import torch

from aion.codecs.base import Codec
from aion.codecs.config import CODEC_CONFIG, CodecType
from aion.modalities import BaseModality, Modality
from aion.codecs.config import MODALITY_CODEC_MAPPING, CodecType, HF_REPO_ID
from aion.modalities import Modality


class ModalityTypeError(TypeError):
Expand All @@ -35,7 +35,9 @@ def __init__(self, device: str | torch.device = "cpu"):

@staticmethod
@lru_cache
def _load_codec_from_hf(codec_class: CodecType, hf_codec_repo_id: str) -> Codec:
def _load_codec_from_hf(
codec_class: CodecType, modality_type: type[Modality]
) -> Codec:
"""Load a codec from HuggingFace.
Although HF download is already cached,
the method is cached to avoid reloading the same codec.
Expand All @@ -47,29 +49,28 @@ def _load_codec_from_hf(codec_class: CodecType, hf_codec_repo_id: str) -> Codec:
Returns:
The loaded codec
"""
codec = codec_class.from_pretrained(hf_codec_repo_id)

codec = codec_class.from_pretrained(HF_REPO_ID, modality=modality_type)
codec = codec.eval()
return codec

@lru_cache
def _load_codec(self, modality_type: type[Modality]) -> Codec:
    """Load the codec registered for a modality type.

    The lookup first tries ``modality_type`` itself in
    ``MODALITY_CODEC_MAPPING`` and then falls back to its direct base class
    (``modality_type.__base__``), so that a modality subclass can reuse the
    codec registered for its parent.

    Args:
        modality_type: The modality class to find a codec for.

    Returns:
        The codec returned by ``_load_codec_from_hf`` for the matched entry.

    Raises:
        ModalityTypeError: If neither the type nor its direct base class has
            an entry in ``MODALITY_CODEC_MAPPING``.
    """
    # NOTE: lru_cache on a method includes ``self`` in the cache key and
    # keeps the instance referenced for as long as the cache entry lives.

    # Only the exact type and its immediate base are considered; the rest of
    # the MRO is not walked, matching the original lookup.
    if modality_type in MODALITY_CODEC_MAPPING:
        registered_type = modality_type
    elif getattr(modality_type, "__base__", None) in MODALITY_CODEC_MAPPING:
        registered_type = modality_type.__base__
    else:
        raise ModalityTypeError(
            f"No codec configuration found for modality type: {modality_type.__name__}"
        )

    codec_class = MODALITY_CODEC_MAPPING[registered_type]

    # Hand the loader the type that actually matched the registry, so the
    # codec class and the modality used to locate its weights always agree
    # (previously the unregistered subclass was forwarded on the fallback
    # path). NOTE(review): from_pretrained is not visible here; this assumes
    # pretrained weights are keyed by the registered modality - confirm.
    return self._load_codec_from_hf(codec_class, registered_type)

Expand Down
4 changes: 2 additions & 2 deletions aion/codecs/scalar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Type, Optional, Dict, Any

from huggingface_hub import PyTorchModelHubMixin
from jaxtyping import Float
from torch import Tensor

Expand All @@ -11,10 +10,11 @@
MultiScalarCompressedReservoirQuantizer,
)
from aion.codecs.base import Codec
from aion.codecs.utils import CodecPytorchHubMixin
from aion.modalities import Scalar, ScalarModalities


class BaseScalarIdentityCodec(Codec, PyTorchModelHubMixin):
class BaseScalarIdentityCodec(Codec, CodecPytorchHubMixin):
"""Codec for scalar quantities.

A codec that embeds scalar quantities through an identity mapping. A
Expand Down
13 changes: 6 additions & 7 deletions aion/codecs/scalar_field.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from functools import reduce
from typing import Callable, Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Callable, Optional, Type
from jaxtyping import Float
from torch import Tensor

from huggingface_hub import PyTorchModelHubMixin
from aion.codecs.utils import CodecPytorchHubMixin
from aion.modalities import LegacySurveySegmentationMap

from .base import Codec
from .quantizers import Quantizer, FiniteScalarQuantizer
from .modules.convblocks import Encoder2d, Decoder2d
from .modules.convblocks import Decoder2d, Encoder2d
from .modules.ema import ModelEmaV2
from aion.modalities import LegacySurveySegmentationMap
from .preprocessing.image import CenterCrop
from .quantizers import FiniteScalarQuantizer, Quantizer


def _deep_get(dictionary, path, default=None):
Expand Down Expand Up @@ -171,7 +170,7 @@ def _output_activation(
# ======================================================================================


class ScalarFieldCodec(AutoencoderScalarFieldCodec, PyTorchModelHubMixin):
class ScalarFieldCodec(AutoencoderScalarFieldCodec, CodecPytorchHubMixin):
"""Convolutional autoencoder codec for scalar fields."""

def __init__(
Expand Down
11 changes: 6 additions & 5 deletions aion/codecs/spectrum.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Type

import torch
from huggingface_hub import PyTorchModelHubMixin
from jaxtyping import Float, Real
from typing import Type

from aion.modalities import Spectrum
from aion.codecs.base import Codec
from aion.codecs.modules.convnext import ConvNextDecoder1d, ConvNextEncoder1d
from aion.codecs.modules.spectrum import LatentSpectralGrid
from aion.codecs.quantizers import LucidrainsLFQ, Quantizer, ScalarLinearQuantizer
from aion.codecs.base import Codec
from aion.codecs.utils import CodecPytorchHubMixin
from aion.modalities import Spectrum


class AutoencoderSpectrumCodec(Codec):
Expand Down Expand Up @@ -179,7 +180,7 @@ def _decode(
)


class SpectrumCodec(AutoencoderSpectrumCodec, PyTorchModelHubMixin):
class SpectrumCodec(AutoencoderSpectrumCodec, CodecPytorchHubMixin):
"""Spectrum codec based on convnext blocks."""

def __init__(
Expand Down
Loading
Loading