diff --git a/aion/codecs/catalog.py b/aion/codecs/catalog.py
index a4dd2c1..8aaa23f 100644
--- a/aion/codecs/catalog.py
+++ b/aion/codecs/catalog.py
@@ -1,7 +1,6 @@
 from collections import OrderedDict
-from typing import Type, Optional, Dict
+from typing import Dict, Optional, Type
 
-from huggingface_hub import PyTorchModelHubMixin
 import torch
 from jaxtyping import Float
 from torch import Tensor
@@ -13,10 +12,11 @@
     IdentityQuantizer,
     ScalarReservoirQuantizer,
 )
+from aion.codecs.utils import CodecPytorchHubMixin
 from aion.modalities import LegacySurveyCatalog
 
 
-class CatalogCodec(Codec, PyTorchModelHubMixin):
+class CatalogCodec(Codec, CodecPytorchHubMixin):
     """Codec for catalog quantities.
 
     A codec that embeds catalog quantities through an identity mapping. A
diff --git a/aion/codecs/config.py b/aion/codecs/config.py
index 1d545b3..bb3734c 100644
--- a/aion/codecs/config.py
+++ b/aion/codecs/config.py
@@ -75,132 +75,45 @@ class CodecHFConfig:
     repo_id: str
 
 
-CODEC_CONFIG = {
-    Image: CodecHFConfig(
-        codec_class=ImageCodec, repo_id="polymathic-ai/aion-image-codec"
-    ),
-    Spectrum: CodecHFConfig(
-        codec_class=SpectrumCodec, repo_id="polymathic-ai/aion-spectrum-codec"
-    ),
-    LegacySurveyCatalog: CodecHFConfig(
-        codec_class=CatalogCodec, repo_id="polymathic-ai/aion-catalog-codec"
-    ),
-    LegacySurveySegmentationMap: CodecHFConfig(
-        codec_class=ScalarFieldCodec, repo_id="polymathic-ai/aion-scalar-field-codec"
-    ),
-    # Scalar modalities
-    # LogScalarCodec
-    LegacySurveyFluxG: CodecHFConfig(
-        codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-g-codec"
-    ),
-    LegacySurveyFluxR: CodecHFConfig(
-        codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-r-codec"
-    ),
-    LegacySurveyFluxI: CodecHFConfig(
-        codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-i-codec"
-    ),
-    LegacySurveyFluxZ: CodecHFConfig(
-        codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-z-codec"
-    ),
-    LegacySurveyFluxW1: CodecHFConfig(
-        codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-w1-codec"
-    ),
-    LegacySurveyFluxW2: CodecHFConfig(
-        codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-w2-codec"
-    ),
-    LegacySurveyFluxW3: CodecHFConfig(
-        codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-w3-codec"
-    ),
-    LegacySurveyFluxW4: CodecHFConfig(
-        codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-flux-w4-codec"
-    ),
-    LegacySurveyShapeR: CodecHFConfig(
-        codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-shape-r-codec"
-    ),
-    GaiaFluxG: CodecHFConfig(
-        codec_class=LogScalarCodec,
-        repo_id="polymathic-ai/aion-scalar-phot-g-mean-flux-codec",
-    ),
-    GaiaFluxBp: CodecHFConfig(
-        codec_class=LogScalarCodec,
-        repo_id="polymathic-ai/aion-scalar-phot-bp-mean-flux-codec",
-    ),
-    GaiaFluxRp: CodecHFConfig(
-        codec_class=LogScalarCodec,
-        repo_id="polymathic-ai/aion-scalar-phot-rp-mean-flux-codec",
-    ),
-    GaiaParallax: CodecHFConfig(
-        codec_class=LogScalarCodec, repo_id="polymathic-ai/aion-scalar-parallax-codec"
-    ),
-    # ScalarCodec
-    LegacySurveyShapeE1: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-shape-e1-codec"
-    ),
-    LegacySurveyShapeE2: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-shape-e2-codec"
-    ),
-    LegacySurveyEBV: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-ebv-codec"
-    ),
-    HSCMagG: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-g-cmodel-mag-codec"
-    ),
-    HSCMagR: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-r-cmodel-mag-codec"
-    ),
-    HSCMagI: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-i-cmodel-mag-codec"
-    ),
-    HSCMagZ: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-z-cmodel-mag-codec"
-    ),
-    HSCMagY: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-y-cmodel-mag-codec"
-    ),
-    HSCShape11: CodecHFConfig(
-        codec_class=ScalarCodec,
-        repo_id="polymathic-ai/aion-scalar-i-sdssshape-shape11-codec",
-    ),
-    HSCShape22: CodecHFConfig(
-        codec_class=ScalarCodec,
-        repo_id="polymathic-ai/aion-scalar-i-sdssshape-shape22-codec",
-    ),
-    HSCShape12: CodecHFConfig(
-        codec_class=ScalarCodec,
-        repo_id="polymathic-ai/aion-scalar-i-sdssshape-shape12-codec",
-    ),
-    HSCAG: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-a-g-codec"
-    ),
-    HSCAR: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-a-r-codec"
-    ),
-    HSCAI: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-a-i-codec"
-    ),
-    HSCAZ: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-a-z-codec"
-    ),
-    HSCAY: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-a-y-codec"
-    ),
-    Ra: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-ra-codec"
-    ),
-    Dec: CodecHFConfig(
-        codec_class=ScalarCodec, repo_id="polymathic-ai/aion-scalar-dec-codec"
-    ),
-    # MultiScalarCodec
-    GaiaXpBp: CodecHFConfig(
-        codec_class=MultiScalarCodec,
-        repo_id="polymathic-ai/aion-scalar-bp-coefficients-codec",
-    ),
-    GaiaXpRp: CodecHFConfig(
-        codec_class=MultiScalarCodec,
-        repo_id="polymathic-ai/aion-scalar-rp-coefficients-codec",
-    ),
-    # GridScalarCodec
-    Z: CodecHFConfig(
-        codec_class=GridScalarCodec, repo_id="polymathic-ai/aion-scalar-z-codec"
-    ),
+MODALITY_CODEC_MAPPING = {
+    Image: ImageCodec,
+    Spectrum: SpectrumCodec,
+    LegacySurveyCatalog: CatalogCodec,
+    LegacySurveySegmentationMap: ScalarFieldCodec,
+    LegacySurveyFluxG: LogScalarCodec,
+    LegacySurveyFluxR: LogScalarCodec,
+    LegacySurveyFluxI: LogScalarCodec,
+    LegacySurveyFluxZ: LogScalarCodec,
+    LegacySurveyFluxW1: LogScalarCodec,
+    LegacySurveyFluxW2: LogScalarCodec,
+    LegacySurveyFluxW3: LogScalarCodec,
+    LegacySurveyFluxW4: LogScalarCodec,
+    LegacySurveyShapeR: LogScalarCodec,
+    GaiaFluxG: LogScalarCodec,
+    GaiaFluxBp: LogScalarCodec,
+    GaiaFluxRp: LogScalarCodec,
+    GaiaParallax: LogScalarCodec,
+    LegacySurveyShapeE1: ScalarCodec,
+    LegacySurveyShapeE2: ScalarCodec,
+    LegacySurveyEBV: ScalarCodec,
+    HSCMagG: ScalarCodec,
+    HSCMagR: ScalarCodec,
+    HSCMagI: ScalarCodec,
+    HSCMagZ: ScalarCodec,
+    HSCMagY: ScalarCodec,
+    HSCShape11: ScalarCodec,
+    HSCShape22: ScalarCodec,
+    HSCShape12: ScalarCodec,
+    HSCAG: ScalarCodec,
+    HSCAR: ScalarCodec,
+    HSCAI: ScalarCodec,
+    HSCAZ: ScalarCodec,
+    HSCAY: ScalarCodec,
+    Ra: ScalarCodec,
+    Dec: ScalarCodec,
+    GaiaXpBp: MultiScalarCodec,
+    GaiaXpRp: MultiScalarCodec,
+    Z: GridScalarCodec,
 }
+
+HF_REPO_ID = "polymathic-ai/aion-base"
diff --git a/aion/codecs/image.py b/aion/codecs/image.py
index 1b0b0ef..5f67609 100644
--- a/aion/codecs/image.py
+++ b/aion/codecs/image.py
@@ -1,5 +1,4 @@
 import torch
-from huggingface_hub import PyTorchModelHubMixin
 from jaxtyping import Float
 from torch import Tensor
 from typing import Type, Optional, List
@@ -16,6 +15,7 @@
     Clamp,
 )
 from aion.codecs.preprocessing.band_to_index import BAND_TO_INDEX
+from aion.codecs.utils import CodecPytorchHubMixin
 
 
 class AutoencoderImageCodec(Codec):
@@ -165,7 +165,7 @@ def decode(
         return super().decode(z, bands=bands)
 
 
-class ImageCodec(AutoencoderImageCodec, PyTorchModelHubMixin):
+class ImageCodec(AutoencoderImageCodec, CodecPytorchHubMixin):
     def __init__(
         self,
         quantizer_levels: List[int],
diff --git a/aion/codecs/manager.py b/aion/codecs/manager.py
index b7dc932..d04b363 100644
--- a/aion/codecs/manager.py
+++ b/aion/codecs/manager.py
@@ -9,8 +9,8 @@
 import torch
 
 from aion.codecs.base import Codec
-from aion.codecs.config import CODEC_CONFIG, CodecType
-from aion.modalities import BaseModality, Modality
+from aion.codecs.config import MODALITY_CODEC_MAPPING, CodecType, HF_REPO_ID
+from aion.modalities import Modality
 
 
 class ModalityTypeError(TypeError):
@@ -35,7 +35,9 @@ def __init__(self, device: str | torch.device = "cpu"):
     @staticmethod
     @lru_cache
-    def _load_codec_from_hf(codec_class: CodecType, hf_codec_repo_id: str) -> Codec:
+    def _load_codec_from_hf(
+        codec_class: CodecType, modality_type: type[Modality]
+    ) -> Codec:
         """Load a codec from HuggingFace.
 
         Although HF download is already cached, the method is cached to avoid
         reloading the same codec.
@@ -47,29 +49,28 @@ def _load_codec_from_hf(codec_class: CodecType, hf_codec_repo_id: str) -> Codec:
         Returns:
             The loaded codec
         """
-        codec = codec_class.from_pretrained(hf_codec_repo_id)
+
+        codec = codec_class.from_pretrained(HF_REPO_ID, modality=modality_type)
         codec = codec.eval()
         return codec
 
     @lru_cache
-    def _load_codec(self, modality_type: type[BaseModality]) -> Codec:
+    def _load_codec(self, modality_type: type[Modality]) -> Codec:
         """Load a codec for the given modality type."""
-        # Look up configuration in CODEC_CONFIG
-        if modality_type in CODEC_CONFIG:
-            config = CODEC_CONFIG[modality_type]
+        # Look up the codec class in MODALITY_CODEC_MAPPING
+        if modality_type in MODALITY_CODEC_MAPPING:
+            codec_class = MODALITY_CODEC_MAPPING[modality_type]
         elif (
             hasattr(modality_type, "__base__")
-            and modality_type.__base__ in CODEC_CONFIG
+            and modality_type.__base__ in MODALITY_CODEC_MAPPING
         ):
-            config = CODEC_CONFIG[modality_type.__base__]
+            codec_class = MODALITY_CODEC_MAPPING[modality_type.__base__]
         else:
             raise ModalityTypeError(
                 f"No codec configuration found for modality type: {modality_type.__name__}"
             )
-        codec_class = config.codec_class
-        hf_codec_repo_id = config.repo_id
-        codec = self._load_codec_from_hf(codec_class, hf_codec_repo_id)
+        codec = self._load_codec_from_hf(codec_class, modality_type)
         return codec
diff --git a/aion/codecs/scalar.py b/aion/codecs/scalar.py
index 8308fb6..e40b235 100644
--- a/aion/codecs/scalar.py
+++ b/aion/codecs/scalar.py
@@ -1,6 +1,5 @@
 from typing import Type, Optional, Dict, Any
 
-from huggingface_hub import PyTorchModelHubMixin
 from jaxtyping import Float
 from torch import Tensor
 
@@ -11,10 +10,11 @@
     MultiScalarCompressedReservoirQuantizer,
 )
 from aion.codecs.base import Codec
+from aion.codecs.utils import CodecPytorchHubMixin
 from aion.modalities import Scalar, ScalarModalities
 
 
-class BaseScalarIdentityCodec(Codec, PyTorchModelHubMixin):
+class BaseScalarIdentityCodec(Codec, CodecPytorchHubMixin):
     """Codec for scalar quantities.
 
     A codec that embeds scalar quantities through an identity mapping. A
diff --git a/aion/codecs/scalar_field.py b/aion/codecs/scalar_field.py
index 875e3c5..c736a9b 100644
--- a/aion/codecs/scalar_field.py
+++ b/aion/codecs/scalar_field.py
@@ -1,21 +1,20 @@
 from functools import reduce
+from typing import Callable, Optional, Type
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-from typing import Callable, Optional, Type
 from jaxtyping import Float
 from torch import Tensor
 
-from huggingface_hub import PyTorchModelHubMixin
+from aion.codecs.utils import CodecPytorchHubMixin
+from aion.modalities import LegacySurveySegmentationMap
 
 from .base import Codec
-from .quantizers import Quantizer, FiniteScalarQuantizer
-from .modules.convblocks import Encoder2d, Decoder2d
+from .modules.convblocks import Decoder2d, Encoder2d
 from .modules.ema import ModelEmaV2
-from aion.modalities import LegacySurveySegmentationMap
 from .preprocessing.image import CenterCrop
+from .quantizers import FiniteScalarQuantizer, Quantizer
 
 
 def _deep_get(dictionary, path, default=None):
@@ -171,7 +170,7 @@ def _output_activation(
 # ======================================================================================
 
 
-class ScalarFieldCodec(AutoencoderScalarFieldCodec, PyTorchModelHubMixin):
+class ScalarFieldCodec(AutoencoderScalarFieldCodec, CodecPytorchHubMixin):
     """Convolutional autoencoder codec for scalar fields."""
 
     def __init__(
diff --git a/aion/codecs/spectrum.py b/aion/codecs/spectrum.py
index 67b75ea..1f4c0c3 100644
--- a/aion/codecs/spectrum.py
+++ b/aion/codecs/spectrum.py
@@ -1,13 +1,14 @@
+from typing import Type
+
 import torch
-from huggingface_hub import PyTorchModelHubMixin
 from jaxtyping import Float, Real
-from typing import Type
 
-from aion.modalities import Spectrum
+from aion.codecs.base import Codec
 from aion.codecs.modules.convnext import ConvNextDecoder1d, ConvNextEncoder1d
 from aion.codecs.modules.spectrum import LatentSpectralGrid
 from aion.codecs.quantizers import LucidrainsLFQ, Quantizer, ScalarLinearQuantizer
-from aion.codecs.base import Codec
+from aion.codecs.utils import CodecPytorchHubMixin
+from aion.modalities import Spectrum
 
 
 class AutoencoderSpectrumCodec(Codec):
@@ -179,7 +180,7 @@ def _decode(
         )
 
 
-class SpectrumCodec(AutoencoderSpectrumCodec, PyTorchModelHubMixin):
+class SpectrumCodec(AutoencoderSpectrumCodec, CodecPytorchHubMixin):
     """Spectrum codec based on convnext blocks."""
 
     def __init__(
diff --git a/aion/codecs/utils.py b/aion/codecs/utils.py
new file mode 100644
index 0000000..4b0c1a6
--- /dev/null
+++ b/aion/codecs/utils.py
@@ -0,0 +1,155 @@
+from contextlib import contextmanager
+from threading import local
+from typing import Optional
+
+from huggingface_hub import hub_mixin
+
+from aion.codecs.base import Codec
+from aion.modalities import Modality
+
+ORIGINAL_CONFIG_NAME = hub_mixin.constants.CONFIG_NAME
+ORIGINAL_PYTORCH_WEIGHTS_NAME = hub_mixin.constants.PYTORCH_WEIGHTS_NAME
+ORIGINAL_SAFETENSORS_SINGLE_FILE = hub_mixin.constants.SAFETENSORS_SINGLE_FILE
+
+# Thread-local storage for codec context
+_thread_local = local()
+
+
+@contextmanager
+def _codec_path_context(modality: type[Modality]):
+    """Thread-safe context manager for temporarily overriding HuggingFace constants.
+
+    Args:
+        modality: The modality type to create paths for
+
+    Yields:
+        None
+    """
+    # Store original values
+    original_config = hub_mixin.constants.CONFIG_NAME
+    original_weights = hub_mixin.constants.PYTORCH_WEIGHTS_NAME
+    original_safetensors = hub_mixin.constants.SAFETENSORS_SINGLE_FILE
+
+    try:
+        # Set codec-specific paths
+        hub_mixin.constants.CONFIG_NAME = (
+            f"codecs/{modality.name}/{ORIGINAL_CONFIG_NAME}"
+        )
+        hub_mixin.constants.PYTORCH_WEIGHTS_NAME = (
+            f"codecs/{modality.name}/{ORIGINAL_PYTORCH_WEIGHTS_NAME}"
+        )
+        hub_mixin.constants.SAFETENSORS_SINGLE_FILE = (
+            f"codecs/{modality.name}/{ORIGINAL_SAFETENSORS_SINGLE_FILE}"
+        )
+        yield
+    finally:
+        # Always restore original values
+        hub_mixin.constants.CONFIG_NAME = original_config
+        hub_mixin.constants.PYTORCH_WEIGHTS_NAME = original_weights
+        hub_mixin.constants.SAFETENSORS_SINGLE_FILE = original_safetensors
+
+
+def _validate_modality(modality: type[Modality]) -> None:
+    """Validate that the modality is properly configured.
+
+    Args:
+        modality: The modality type to validate
+
+    Raises:
+        ValueError: If the modality is invalid
+    """
+    if not isinstance(modality, type):
+        raise ValueError(f"Expected modality to be a type, got {type(modality)}")
+
+    if not issubclass(modality, Modality):
+        raise ValueError(f"Modality {modality} must be a subclass of Modality")
+
+    if not hasattr(modality, "name") or not isinstance(modality.name, str):
+        raise ValueError(
+            f"Modality {modality} must have a 'name' class attribute of type str"
+        )
+
+    if not modality.name.strip():
+        raise ValueError(f"Modality {modality} name cannot be empty")
+
+
+class CodecPytorchHubMixin(hub_mixin.PyTorchModelHubMixin):
+    """Mixin for PyTorch models that correspond to codecs.
+
+    Codecs don't have their own model repos; instead, they live in the
+    transformer model repo as subfolders.
+    """
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path,
+        modality: type[Modality],
+        *model_args,
+        **kwargs,
+    ):
+        """Load a codec model from a pretrained model repository.
+
+        Args:
+            pretrained_model_name_or_path (str): The name or path of the pretrained
+                model repository.
+            modality (type[Modality]): The modality type for this codec.
+            *model_args: Additional positional arguments to pass to the model
+                constructor.
+            **kwargs: Additional keyword arguments to pass to the model
+                constructor.
+
+        Returns:
+            The loaded codec model.
+
+        Raises:
+            ValueError: If the class is not a codec subclass or modality is invalid.
+        """
+        if not issubclass(cls, Codec):
+            raise ValueError("Only codec classes can be loaded using this method.")
+
+        # Validate modality
+        _validate_modality(modality)
+
+        # Use thread-safe context manager to override paths
+        with _codec_path_context(modality):
+            model = super().from_pretrained(
+                pretrained_model_name_or_path, *model_args, **kwargs
+            )
+
+        # Store modality reference on the model instance for later use
+        model._modality = modality
+        return model
+
+    def save_pretrained(
+        self, save_directory, modality: Optional[type[Modality]] = None, *args, **kwargs
+    ):
+        """Save the codec model to a pretrained model repository.
+
+        Args:
+            save_directory (str): The directory to save the model to.
+            modality (Optional[type[Modality]]): The modality type for this codec.
+                If not provided, will use the modality stored during from_pretrained.
+            *args: Additional positional arguments to pass to the save method.
+            **kwargs: Additional keyword arguments to pass to the save method.
+
+        Raises:
+            ValueError: If the instance is not a codec or modality cannot be determined.
+        """
+        if not issubclass(self.__class__, Codec):
+            raise ValueError("Only codec instances can be saved using this method.")
+
+        # Determine modality to use
+        if modality is not None:
+            _validate_modality(modality)
+            target_modality = modality
+        elif hasattr(self, "_modality"):
+            target_modality = self._modality
+        else:
+            raise ValueError(
+                "No modality specified. Either provide modality parameter or "
+                "load the codec using from_pretrained() which stores the modality."
+            )
+
+        # Construct the path to the codec subfolder
+        codec_path = f"{save_directory}/codecs/{target_modality.name}"
+        super().save_pretrained(codec_path, *args, **kwargs)
diff --git a/aion/modalities.py b/aion/modalities.py
index 951d22a..9bfdb44 100644
--- a/aion/modalities.py
+++ b/aion/modalities.py
@@ -49,23 +49,20 @@
 ]
 
 
-class BaseModality(ABC):
+class Modality(ABC):
     """Base class for all modality data types."""
 
-
-class Modality(BaseModality, ABC):
-    """Base class for all token modalities."""
-
     token_key: ClassVar[str] = ""
 
 
 @dataclass
-class Image(BaseModality):
+class Image(Modality):
     """Base class for image modality data.
 
     This is an abstract base class. Use LegacySurveyImage or HSCImage instead.
     """
 
+    name: ClassVar[str] = "image"
     flux: Float[Tensor, " batch num_bands height width"]
     bands: list[str]
@@ -74,25 +71,26 @@ def __repr__(self) -> str:
         return repr_str
 
 
-class HSCImage(Image, Modality):
+class HSCImage(Image):
     """HSC image modality data."""
 
     token_key: ClassVar[str] = "tok_image_hsc"
 
 
-class LegacySurveyImage(Image, Modality):
+class LegacySurveyImage(Image):
     """Legacy Survey image modality data."""
 
     token_key: ClassVar[str] = "tok_image"
 
 
 @dataclass
-class Spectrum(BaseModality):
+class Spectrum(Modality):
     """Base class for spectrum modality data.
 
     This is an abstract base class. Use DESISpectrum or SDSSSpectrum instead.
     """
 
+    name: ClassVar[str] = "spectrum"
     flux: Float[Tensor, " batch length"]
     ivar: Float[Tensor, " batch length"]
     mask: Bool[Tensor, " batch length"]
@@ -107,13 +105,13 @@ def __repr__(self) -> str:
         return repr_str
 
 
-class DESISpectrum(Spectrum, Modality):
+class DESISpectrum(Spectrum):
     """DESI spectrum modality data."""
 
     token_key: ClassVar[str] = "tok_spectrum_desi"
 
 
-class SDSSSpectrum(Spectrum, Modality):
+class SDSSSpectrum(Spectrum):
     """SDSS spectrum modality data."""
 
     token_key: ClassVar[str] = "tok_spectrum_sdss"
@@ -127,6 +125,7 @@ class LegacySurveyCatalog(Modality):
 
     Represents a catalog of scalar values from the Legacy Survey.
     """
+    name: ClassVar[str] = "catalog"
     X: Int[Tensor, " batch n"]
     Y: Int[Tensor, " batch n"]
     SHAPE_E1: Float[Tensor, " batch n"]
@@ -142,11 +141,12 @@ class LegacySurveySegmentationMap(Modality):
 
     Represents 2D segmentation maps built from Legacy Survey detections.
""" + name: ClassVar[str] = "segmentation_map" field: Float[Tensor, " batch height width"] token_key: ClassVar[str] = "tok_segmap" def __repr__(self) -> str: - repr_str = f"LegacySurveySegmentationMap(field_shape={list(self.field.shape)})" + repr_str = f"{self.__class__.__name__}>(field_shape={list(self.field.shape)})" return repr_str @@ -165,56 +165,56 @@ def __repr__(self) -> str: # Flux measurements in different bands -class LegacySurveyFluxG(Scalar, Modality): +class LegacySurveyFluxG(Scalar): """G-band flux measurement from Legacy Survey.""" name: ClassVar[str] = "FLUX_G" token_key: ClassVar[str] = "tok_flux_g" -class LegacySurveyFluxR(Scalar, Modality): +class LegacySurveyFluxR(Scalar): """R-band flux measurement.""" name: ClassVar[str] = "FLUX_R" token_key: ClassVar[str] = "tok_flux_r" -class LegacySurveyFluxI(Scalar, Modality): +class LegacySurveyFluxI(Scalar): """I-band flux measurement.""" name: ClassVar[str] = "FLUX_I" token_key: ClassVar[str] = "tok_flux_i" -class LegacySurveyFluxZ(Scalar, Modality): +class LegacySurveyFluxZ(Scalar): """Z-band flux measurement.""" name: ClassVar[str] = "FLUX_Z" token_key: ClassVar[str] = "tok_flux_z" -class LegacySurveyFluxW1(Scalar, Modality): +class LegacySurveyFluxW1(Scalar): """WISE W1-band flux measurement.""" name: ClassVar[str] = "FLUX_W1" token_key: ClassVar[str] = "tok_flux_w1" -class LegacySurveyFluxW2(Scalar, Modality): +class LegacySurveyFluxW2(Scalar): """WISE W2-band flux measurement.""" name: ClassVar[str] = "FLUX_W2" token_key: ClassVar[str] = "tok_flux_w2" -class LegacySurveyFluxW3(Scalar, Modality): +class LegacySurveyFluxW3(Scalar): """WISE W3-band flux measurement.""" name: ClassVar[str] = "FLUX_W3" token_key: ClassVar[str] = "tok_flux_w3" -class LegacySurveyFluxW4(Scalar, Modality): +class LegacySurveyFluxW4(Scalar): """WISE W4-band flux measurement.""" name: ClassVar[str] = "FLUX_W4" @@ -222,21 +222,21 @@ class LegacySurveyFluxW4(Scalar, Modality): # Shape parameters -class LegacySurveyShapeR(Scalar, Modality): +class LegacySurveyShapeR(Scalar): """R-band shape measurement (e.g., half-light radius).""" name: ClassVar[str] = "SHAPE_R" token_key: ClassVar[str] = "tok_shape_r" -class LegacySurveyShapeE1(Scalar, Modality): +class LegacySurveyShapeE1(Scalar): """First ellipticity component.""" name: ClassVar[str] = "SHAPE_E1" token_key: ClassVar[str] = "tok_shape_e1" -class LegacySurveyShapeE2(Scalar, Modality): +class LegacySurveyShapeE2(Scalar): """Second ellipticity component.""" name: ClassVar[str] = "SHAPE_E2" @@ -244,7 +244,7 @@ class LegacySurveyShapeE2(Scalar, Modality): # Other scalar properties -class LegacySurveyEBV(Scalar, Modality): +class LegacySurveyEBV(Scalar): """E(B-V) extinction measurement.""" name: ClassVar[str] = "EBV" @@ -252,7 +252,7 @@ class LegacySurveyEBV(Scalar, Modality): # Spectroscopic redshift -class Z(Scalar, Modality): +class Z(Scalar): """Spectroscopic redshift measurement.""" name: ClassVar[str] = "Z" @@ -260,91 +260,91 @@ class Z(Scalar, Modality): # Extinction values from HSC -class HSCAG(Scalar, Modality): +class HSCAG(Scalar): """HSC a_g extinction.""" name: ClassVar[str] = "a_g" token_key: ClassVar[str] = "tok_a_g" -class HSCAR(Scalar, Modality): +class HSCAR(Scalar): """HSC a_r extinction.""" name: ClassVar[str] = "a_r" token_key: ClassVar[str] = "tok_a_r" -class HSCAI(Scalar, Modality): +class HSCAI(Scalar): """HSC a_i extinction.""" name: ClassVar[str] = "a_i" token_key: ClassVar[str] = "tok_a_i" -class HSCAZ(Scalar, Modality): +class HSCAZ(Scalar): """HSC a_z extinction.""" name: 
ClassVar[str] = "a_z" token_key: ClassVar[str] = "tok_a_z" -class HSCAY(Scalar, Modality): +class HSCAY(Scalar): """HSC a_y extinction.""" name: ClassVar[str] = "a_y" token_key: ClassVar[str] = "tok_a_y" -class HSCMagG(Scalar, Modality): +class HSCMagG(Scalar): """HSC g-band cmodel magnitude.""" name: ClassVar[str] = "g_cmodel_mag" token_key: ClassVar[str] = "tok_mag_g" -class HSCMagR(Scalar, Modality): +class HSCMagR(Scalar): """HSC r-band cmodel magnitude.""" name: ClassVar[str] = "r_cmodel_mag" token_key: ClassVar[str] = "tok_mag_r" -class HSCMagI(Scalar, Modality): +class HSCMagI(Scalar): """HSC i-band cmodel magnitude.""" name: ClassVar[str] = "i_cmodel_mag" token_key: ClassVar[str] = "tok_mag_i" -class HSCMagZ(Scalar, Modality): +class HSCMagZ(Scalar): """HSC z-band cmodel magnitude.""" name: ClassVar[str] = "z_cmodel_mag" token_key: ClassVar[str] = "tok_mag_z" -class HSCMagY(Scalar, Modality): +class HSCMagY(Scalar): """HSC y-band cmodel magnitude.""" name: ClassVar[str] = "y_cmodel_mag" token_key: ClassVar[str] = "tok_mag_y" -class HSCShape11(Scalar, Modality): +class HSCShape11(Scalar): """HSC i-band SDSS shape 11 component.""" name: ClassVar[str] = "i_sdssshape_shape11" token_key: ClassVar[str] = "tok_shape11" -class HSCShape22(Scalar, Modality): +class HSCShape22(Scalar): """HSC i-band SDSS shape 22 component.""" name: ClassVar[str] = "i_sdssshape_shape22" token_key: ClassVar[str] = "tok_shape22" -class HSCShape12(Scalar, Modality): +class HSCShape12(Scalar): """HSC i-band SDSS shape 12 component.""" name: ClassVar[str] = "i_sdssshape_shape12" @@ -352,56 +352,56 @@ class HSCShape12(Scalar, Modality): # Gaia modalities -class GaiaFluxG(Scalar, Modality): +class GaiaFluxG(Scalar): """Gaia G-band mean flux.""" name: ClassVar[str] = "phot_g_mean_flux" token_key: ClassVar[str] = "tok_flux_g_gaia" -class GaiaFluxBp(Scalar, Modality): +class GaiaFluxBp(Scalar): """Gaia BP-band mean flux.""" name: ClassVar[str] = "phot_bp_mean_flux" token_key: ClassVar[str] = "tok_flux_bp_gaia" -class GaiaFluxRp(Scalar, Modality): +class GaiaFluxRp(Scalar): """Gaia RP-band mean flux.""" name: ClassVar[str] = "phot_rp_mean_flux" token_key: ClassVar[str] = "tok_flux_rp_gaia" -class GaiaParallax(Scalar, Modality): +class GaiaParallax(Scalar): """Gaia parallax measurement.""" name: ClassVar[str] = "parallax" token_key: ClassVar[str] = "tok_parallax" -class Ra(Scalar, Modality): +class Ra(Scalar): """Right ascension coordinate.""" name: ClassVar[str] = "ra" token_key: ClassVar[str] = "tok_ra" -class Dec(Scalar, Modality): +class Dec(Scalar): """Declination coordinate.""" name: ClassVar[str] = "dec" token_key: ClassVar[str] = "tok_dec" -class GaiaXpBp(Scalar, Modality): +class GaiaXpBp(Scalar): """Gaia BP spectral coefficients.""" name: ClassVar[str] = "bp_coefficients" token_key: ClassVar[str] = "tok_xp_bp" -class GaiaXpRp(Scalar, Modality): +class GaiaXpRp(Scalar): """Gaia RP spectral coefficients.""" name: ClassVar[str] = "rp_coefficients" diff --git a/tests/codecs/test_catalog_codec.py b/tests/codecs/test_catalog_codec.py index e454c17..0180f0c 100644 --- a/tests/codecs/test_catalog_codec.py +++ b/tests/codecs/test_catalog_codec.py @@ -3,9 +3,11 @@ from aion.codecs import CatalogCodec from aion.modalities import LegacySurveyCatalog +from aion.codecs.config import HF_REPO_ID + def test_catalog_tokenizer(data_dir): - codec = CatalogCodec.from_pretrained("polymathic-ai/aion-catalog-codec") + codec = CatalogCodec.from_pretrained(HF_REPO_ID, modality=LegacySurveyCatalog) codec.eval() input_batch = torch.load( 
data_dir / "catalog_codec_input_batch.pt", weights_only=False diff --git a/tests/codecs/test_image_codec.py b/tests/codecs/test_image_codec.py index 26c40fd..f60d2d1 100644 --- a/tests/codecs/test_image_codec.py +++ b/tests/codecs/test_image_codec.py @@ -1,8 +1,9 @@ import pytest import torch -from aion.modalities import Image from aion.codecs import ImageCodec +from aion.codecs.config import HF_REPO_ID +from aion.modalities import Image @pytest.mark.parametrize("embedding_dim", [5, 10]) @@ -40,7 +41,7 @@ def test_magvit_image_tokenizer( def test_hf_previous_predictions(data_dir): - codec = ImageCodec.from_pretrained("polymathic-ai/aion-image-codec") + codec = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image) input_batch_dict = torch.load( data_dir / "image_codec_input_batch.pt", weights_only=False @@ -52,7 +53,6 @@ def test_hf_previous_predictions(data_dir): data_dir / "image_codec_decoded_batch.pt", weights_only=False ) with torch.no_grad(): - print(input_batch_dict["image"]["channel_mask"][0]) input_image_obj = Image( flux=input_batch_dict["image"]["array"][:, 5:], bands=["DES-G", "DES-R", "DES-I", "DES-Z"], diff --git a/tests/codecs/test_scalar_codec.py b/tests/codecs/test_scalar_codec.py index 32a7c70..6eb3b80 100644 --- a/tests/codecs/test_scalar_codec.py +++ b/tests/codecs/test_scalar_codec.py @@ -1,43 +1,44 @@ import pytest import torch -from aion.codecs import LogScalarCodec, ScalarCodec, MultiScalarCodec, GridScalarCodec +from aion.codecs import GridScalarCodec, LogScalarCodec, MultiScalarCodec, ScalarCodec +from aion.codecs.config import HF_REPO_ID from aion.modalities import ( - LegacySurveyFluxG, - LegacySurveyFluxR, - LegacySurveyFluxI, - LegacySurveyFluxZ, - LegacySurveyFluxW1, - LegacySurveyFluxW2, - LegacySurveyFluxW3, - LegacySurveyFluxW4, - LegacySurveyShapeR, - LegacySurveyShapeE1, - LegacySurveyShapeE2, - LegacySurveyEBV, - Z, - HSCMagG, - HSCMagR, - HSCMagI, - HSCMagZ, - HSCMagY, - HSCShape11, - HSCShape22, - HSCShape12, HSCAG, - HSCAR, HSCAI, - HSCAZ, + HSCAR, HSCAY, + HSCAZ, + Dec, + GaiaFluxBp, # Gaia modalities GaiaFluxG, - GaiaFluxBp, GaiaFluxRp, GaiaParallax, - Ra, - Dec, GaiaXpBp, GaiaXpRp, + HSCMagG, + HSCMagI, + HSCMagR, + HSCMagY, + HSCMagZ, + HSCShape11, + HSCShape12, + HSCShape22, + LegacySurveyEBV, + LegacySurveyFluxG, + LegacySurveyFluxI, + LegacySurveyFluxR, + LegacySurveyFluxW1, + LegacySurveyFluxW2, + LegacySurveyFluxW3, + LegacySurveyFluxW4, + LegacySurveyFluxZ, + LegacySurveyShapeE1, + LegacySurveyShapeE2, + LegacySurveyShapeR, + Ra, + Z, ) @@ -87,9 +88,7 @@ ], ) def test_scalar_tokenizer(data_dir, codec_class, modality): - codec = codec_class.from_pretrained( - f"polymathic-ai/aion-scalar-{modality.name.lower().replace('_', '-')}-codec" - ) + codec = codec_class.from_pretrained(HF_REPO_ID, modality=modality) codec.eval() input_batch = torch.load( data_dir / f"{modality.name}_codec_input_batch.pt", weights_only=False diff --git a/tests/codecs/test_scalar_field_codec.py b/tests/codecs/test_scalar_field_codec.py index 50d0077..8c74ba3 100644 --- a/tests/codecs/test_scalar_field_codec.py +++ b/tests/codecs/test_scalar_field_codec.py @@ -1,11 +1,14 @@ import torch from aion.codecs import ScalarFieldCodec +from aion.codecs.config import HF_REPO_ID from aion.modalities import LegacySurveySegmentationMap def test_scalar_field_tokenizer(data_dir): - codec = ScalarFieldCodec.from_pretrained("polymathic-ai/aion-scalar-field-codec") + codec = ScalarFieldCodec.from_pretrained( + HF_REPO_ID, modality=LegacySurveySegmentationMap + ) codec.eval() input_batch 
= torch.load( data_dir / "scalar-field_codec_input_batch.pt", weights_only=False diff --git a/tests/codecs/test_spectrum_codec.py b/tests/codecs/test_spectrum_codec.py index 6fdbf46..3e0e2a6 100644 --- a/tests/codecs/test_spectrum_codec.py +++ b/tests/codecs/test_spectrum_codec.py @@ -1,11 +1,12 @@ import torch -from aion.modalities import Spectrum from aion.codecs import SpectrumCodec +from aion.codecs.config import HF_REPO_ID +from aion.modalities import Spectrum def test_hf_previous_predictions(data_dir): - codec = SpectrumCodec.from_pretrained("polymathic-ai/aion-spectrum-codec") + codec = SpectrumCodec.from_pretrained(HF_REPO_ID, modality=Spectrum) input_batch = torch.load(data_dir / "SPECTRUM_input_batch.pt", weights_only=False)[ "spectrum"
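
Usage sketch (illustrative, not part of the patch): after this change a codec is loaded from the single polymathic-ai/aion-base repository by passing its modality type instead of a per-codec repo id. The snippet assumes the aion package layout shown above; the random 96x96 flux tensor and the encode() call are illustrative assumptions, while the band names are taken from the image codec test.

import torch

from aion.codecs import ImageCodec
from aion.codecs.config import HF_REPO_ID
from aion.modalities import Image

# Codecs now live as subfolders of the single base repo and are selected
# by modality type rather than by a dedicated repo id.
codec = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image).eval()

image = Image(
    flux=torch.randn(1, 4, 96, 96),  # hypothetical batch: 1 object, 4 bands, 96x96 pixels
    bands=["DES-G", "DES-R", "DES-I", "DES-Z"],
)
with torch.no_grad():
    tokens = codec.encode(image)  # assumption: codecs expose encode()/decode() via the base Codec API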