diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..9990fd1 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,89 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +AION (AstronomIcal Omnimodal Network) is a large omnimodal transformer model for astronomical surveys. It processes 39 distinct astronomical data modalities using a two-stage architecture: + +1. **Modality-specific tokenizers** transform raw inputs (images, spectra, catalogs, scalars) into discrete tokens +2. **Unified encoder-decoder transformer** processes all token streams via multimodal masked modeling (4M) + +The model comes in three variants: Base (300M), Large (800M), and XLarge (3B parameters). + +## Development Commands + +### Testing +```bash +pytest # Run all tests +pytest tests/codecs/ # Run codec tests only +pytest tests/test_data/ # Uses pre-computed test data for validation +``` + +### Linting and Code Quality +```bash +ruff check . # Check code style and lint +ruff check . --fix # Auto-fix linting issues +``` + +### Installation for Development +```bash +pip install -e .[torch,dev] # Install in editable mode with dev dependencies +``` + +### Documentation +```bash +cd docs && make html # Build Sphinx documentation +``` + +## Architecture Overview + +### Core Components + +- **`aion/model.py`**: Main AION wrapper class, inherits from FM (4M) transformer +- **`aion/fourm/`**: 4M (Four-Modal) transformer implementation + - `fm.py`: Core transformer architecture with encoder-decoder blocks + - `modality_info.py`: Configuration for all 39 supported modalities + - `encoder_embeddings.py` / `decoder_embeddings.py`: Modality-specific embedding layers +- **`aion/codecs/`**: Modality tokenization system + - `manager.py`: Dynamic codec loading and management + - `base.py`: Abstract base codec class + - Individual codec implementations for images, spectra, scalars, etc. 
+- **`aion/modalities.py`**: Type definitions for all astronomical data types + +### Key Design Patterns + +1. **Modality System**: Each astronomical data type (flux, spectrum, catalog) has: + - A modality class in `modalities.py` defining data structure + - A codec in `codecs/` for tokenization + - Embedding layers in `fourm/` for the transformer + +2. **Token Keys**: Each modality has a `token_key` (e.g., `tok_image`, `tok_spectrum_sdss`) that maps between modalities and model components + +3. **HuggingFace Integration**: Models and codecs are distributed via HuggingFace Hub with `from_pretrained()` methods + +## Code Conventions + +- Type hints are mandatory, using `jaxtyping` for tensor shapes (e.g., `Float[Tensor, "batch height width"]`) +- Modality classes use `@dataclass` and inherit from `BaseModality` +- All tensor operations should handle device placement explicitly +- Test data is pre-computed and stored in `tests/test_data/` as `.pt` files + +## Testing Strategy + +Tests validate both encoding and decoding for each modality using pre-computed reference data. The test pattern is: +1. Load input, encoded, and decoded reference tensors +2. Run codec encode/decode operations +3. Assert outputs match reference data within tolerance + +Test files follow naming: `test_{modality}_codec.py` + +## Astronomical Context + +The model processes data from major surveys: +- **Legacy Survey**: Optical images and catalogs (g,r,i,z bands + WISE) +- **HSC (Hyper Suprime-Cam)**: Deep optical imaging (g,r,i,z,y bands) +- **Gaia**: Astrometry, photometry, and BP/RP spectra +- **SDSS/DESI**: Optical spectra + +Each modality represents different physical measurements (flux, shape parameters, coordinates, extinction, etc.) that the model learns to correlate. 
diff --git a/aion/codecs/config.py b/aion/codecs/config.py index bb3734c..8835677 100644 --- a/aion/codecs/config.py +++ b/aion/codecs/config.py @@ -18,12 +18,14 @@ HSCAY, HSCAZ, Dec, + DESISpectrum, GaiaFluxBp, GaiaFluxG, GaiaFluxRp, GaiaParallax, GaiaXpBp, GaiaXpRp, + HSCImage, HSCMagG, HSCMagI, HSCMagR, @@ -43,11 +45,13 @@ LegacySurveyFluxW3, LegacySurveyFluxW4, LegacySurveyFluxZ, + LegacySurveyImage, LegacySurveySegmentationMap, LegacySurveyShapeE1, LegacySurveyShapeE2, LegacySurveyShapeR, Ra, + SDSSSpectrum, Spectrum, Z, ) @@ -76,43 +80,47 @@ class CodecHFConfig: MODALITY_CODEC_MAPPING = { + Dec: ScalarCodec, + DESISpectrum: SpectrumCodec, + GaiaFluxBp: LogScalarCodec, + GaiaFluxG: LogScalarCodec, + GaiaFluxRp: LogScalarCodec, + GaiaParallax: LogScalarCodec, + GaiaXpBp: MultiScalarCodec, + GaiaXpRp: MultiScalarCodec, + HSCAG: ScalarCodec, + HSCAI: ScalarCodec, + HSCAR: ScalarCodec, + HSCAY: ScalarCodec, + HSCAZ: ScalarCodec, + HSCImage: ImageCodec, + HSCMagG: ScalarCodec, + HSCMagI: ScalarCodec, + HSCMagR: ScalarCodec, + HSCMagY: ScalarCodec, + HSCMagZ: ScalarCodec, + HSCShape11: ScalarCodec, + HSCShape12: ScalarCodec, + HSCShape22: ScalarCodec, Image: ImageCodec, - Spectrum: SpectrumCodec, LegacySurveyCatalog: CatalogCodec, - LegacySurveySegmentationMap: ScalarFieldCodec, + LegacySurveyEBV: ScalarCodec, LegacySurveyFluxG: LogScalarCodec, - LegacySurveyFluxR: LogScalarCodec, LegacySurveyFluxI: LogScalarCodec, - LegacySurveyFluxZ: LogScalarCodec, + LegacySurveyFluxR: LogScalarCodec, LegacySurveyFluxW1: LogScalarCodec, LegacySurveyFluxW2: LogScalarCodec, LegacySurveyFluxW3: LogScalarCodec, LegacySurveyFluxW4: LogScalarCodec, - LegacySurveyShapeR: LogScalarCodec, - GaiaFluxG: LogScalarCodec, - GaiaFluxBp: LogScalarCodec, - GaiaFluxRp: LogScalarCodec, - GaiaParallax: LogScalarCodec, + LegacySurveyFluxZ: LogScalarCodec, + LegacySurveyImage: ImageCodec, + LegacySurveySegmentationMap: ScalarFieldCodec, LegacySurveyShapeE1: ScalarCodec, LegacySurveyShapeE2: 
ScalarCodec, - LegacySurveyEBV: ScalarCodec, - HSCMagG: ScalarCodec, - HSCMagR: ScalarCodec, - HSCMagI: ScalarCodec, - HSCMagZ: ScalarCodec, - HSCMagY: ScalarCodec, - HSCShape11: ScalarCodec, - HSCShape22: ScalarCodec, - HSCShape12: ScalarCodec, - HSCAG: ScalarCodec, - HSCAR: ScalarCodec, - HSCAI: ScalarCodec, - HSCAZ: ScalarCodec, - HSCAY: ScalarCodec, + LegacySurveyShapeR: LogScalarCodec, Ra: ScalarCodec, - Dec: ScalarCodec, - GaiaXpBp: MultiScalarCodec, - GaiaXpRp: MultiScalarCodec, + SDSSSpectrum: SpectrumCodec, + Spectrum: SpectrumCodec, Z: GridScalarCodec, } diff --git a/aion/codecs/manager.py b/aion/codecs/manager.py index d04b363..cd98ba2 100644 --- a/aion/codecs/manager.py +++ b/aion/codecs/manager.py @@ -60,11 +60,6 @@ def _load_codec(self, modality_type: type[Modality]) -> Codec: # Look up configuration in CODEC_CONFIG if modality_type in MODALITY_CODEC_MAPPING: codec_class = MODALITY_CODEC_MAPPING[modality_type] - elif ( - hasattr(modality_type, "__base__") - and modality_type.__base__ in MODALITY_CODEC_MAPPING - ): - codec_class = MODALITY_CODEC_MAPPING[modality_type.__base__] else: raise ModalityTypeError( f"No codec configuration found for modality type: {modality_type.__name__}" diff --git a/aion/codecs/scalar.py b/aion/codecs/scalar.py index e40b235..a256b4e 100644 --- a/aion/codecs/scalar.py +++ b/aion/codecs/scalar.py @@ -64,7 +64,7 @@ def __init__( reservoir_size: int, ): super().__init__() - self._modality_class = next(m for m in ScalarModalities if m.name == modality) + self._modality_class = ScalarModalities[modality] self._quantizer = ScalarReservoirQuantizer( codebook_size=codebook_size, reservoir_size=reservoir_size, @@ -80,7 +80,7 @@ def __init__( min_log_value: float | None = -3, ): super().__init__() - self._modality_class = next(m for m in ScalarModalities if m.name == modality) + self._modality_class = ScalarModalities[modality] self._quantizer = ScalarLogReservoirQuantizer( codebook_size=codebook_size, reservoir_size=reservoir_size, 
@@ -99,7 +99,7 @@ def __init__( num_quantizers: int, ): super().__init__() - self._modality_class = next(m for m in ScalarModalities if m.name == modality) + self._modality_class = ScalarModalities[modality] self._quantizer = MultiScalarCompressedReservoirQuantizer( compression_fns=compression_fns, decompression_fns=decompression_fns, @@ -112,7 +112,7 @@ def __init__( class GridScalarCodec(BaseScalarIdentityCodec): def __init__(self, modality: str, codebook_size: int): super().__init__() - self._modality_class = next(m for m in ScalarModalities if m.name == modality) + self._modality_class = ScalarModalities[modality] self._quantizer = ScalarLinearQuantizer( codebook_size=codebook_size, range=(0.0, 1.0), diff --git a/aion/codecs/scalar_field.py b/aion/codecs/scalar_field.py index c736a9b..5499a42 100644 --- a/aion/codecs/scalar_field.py +++ b/aion/codecs/scalar_field.py @@ -1,4 +1,3 @@ -from functools import reduce from typing import Callable, Optional, Type import torch @@ -17,10 +16,6 @@ from .quantizers import FiniteScalarQuantizer, Quantizer -def _deep_get(dictionary, path, default=None): - return reduce(lambda d, key: d[key], path.split("."), dictionary) - - class AutoencoderScalarFieldCodec(Codec): """Abstract class for autoencoding scalar field codecs.""" diff --git a/aion/codecs/utils.py b/aion/codecs/utils.py index 4b0c1a6..6c0f686 100644 --- a/aion/codecs/utils.py +++ b/aion/codecs/utils.py @@ -7,6 +7,7 @@ from aion.codecs.base import Codec from aion.modalities import Modality + ORIGINAL_CONFIG_NAME = hub_mixin.constants.CONFIG_NAME ORIGINAL_PYTORCH_WEIGHTS_NAME = hub_mixin.constants.PYTORCH_WEIGHTS_NAME ORIGINAL_SAFETENSORS_SINGLE_FILE = hub_mixin.constants.SAFETENSORS_SINGLE_FILE @@ -79,6 +80,30 @@ class CodecPytorchHubMixin(hub_mixin.PyTorchModelHubMixin): Instead they lie in the transformer model repo as subfolders. 
""" + @staticmethod + def _validate_codec_modality(codec: type[Codec], modality: type[Modality]): + """Validate that a codec class is compatible with a modality. + + Args: + codec: The codec class to validate + modality: The modality type to validate against + + Raises: + TypeError: If the codec is not a valid codec class or is incompatible with the modality + ValueError: If the modality has no corresponding codec configuration + """ + # Import MODALITY_CODEC_MAPPING here to avoid circular import + from aion.codecs.config import MODALITY_CODEC_MAPPING + + if not issubclass(codec, Codec): + raise TypeError("Only codecs can be loaded using this method.") + if modality not in MODALITY_CODEC_MAPPING: + raise ValueError(f"Modality {modality} has no corresponding codec.") + elif MODALITY_CODEC_MAPPING[modality] != codec: + raise TypeError( + f"Modality {modality} is associated with {MODALITY_CODEC_MAPPING[modality]} codec but {codec} requested." + ) + @classmethod def from_pretrained( cls, @@ -104,8 +129,8 @@ def from_pretrained( Raises: ValueError: If the class is not a codec subclass or modality is invalid. 
""" - if not issubclass(cls, Codec): - raise ValueError("Only codec classes can be loaded using this method.") + # Validate codec-modality compatibility + cls._validate_codec_modality(cls, modality) # Validate modality _validate_modality(modality) diff --git a/aion/modalities.py b/aion/modalities.py index 9bfdb44..382b1b5 100644 --- a/aion/modalities.py +++ b/aion/modalities.py @@ -408,42 +408,45 @@ class GaiaXpRp(Scalar): token_key: ClassVar[str] = "tok_xp_rp" -ScalarModalities = [ - LegacySurveyFluxG, - LegacySurveyFluxR, - LegacySurveyFluxI, - LegacySurveyFluxZ, - LegacySurveyFluxW1, - LegacySurveyFluxW2, - LegacySurveyFluxW3, - LegacySurveyFluxW4, - LegacySurveyShapeR, - LegacySurveyShapeE1, - LegacySurveyShapeE2, - LegacySurveyEBV, - Z, - HSCAG, - HSCAR, - HSCAI, - HSCAZ, - HSCAY, - HSCMagG, - HSCMagR, - HSCMagI, - HSCMagZ, - HSCMagY, - HSCShape11, - HSCShape22, - HSCShape12, - GaiaFluxG, - GaiaFluxBp, - GaiaFluxRp, - GaiaParallax, - Ra, - Dec, - GaiaXpBp, - GaiaXpRp, -] +ScalarModalities = { + modality.name: modality + for modality in [ + LegacySurveyFluxG, + LegacySurveyFluxR, + LegacySurveyFluxI, + LegacySurveyFluxZ, + LegacySurveyFluxW1, + LegacySurveyFluxW2, + LegacySurveyFluxW3, + LegacySurveyFluxW4, + LegacySurveyShapeR, + LegacySurveyShapeE1, + LegacySurveyShapeE2, + LegacySurveyEBV, + Z, + HSCAG, + HSCAR, + HSCAI, + HSCAZ, + HSCAY, + HSCMagG, + HSCMagR, + HSCMagI, + HSCMagZ, + HSCMagY, + HSCShape11, + HSCShape22, + HSCShape12, + GaiaFluxG, + GaiaFluxBp, + GaiaFluxRp, + GaiaParallax, + Ra, + Dec, + GaiaXpBp, + GaiaXpRp, + ] +} # Convenience type for any modality data ModalityType = ( diff --git a/tests/codecs/test_load_codecs.py b/tests/codecs/test_load_codecs.py new file mode 100644 index 0000000..37c7bff --- /dev/null +++ b/tests/codecs/test_load_codecs.py @@ -0,0 +1,24 @@ +import pytest +import torch + +from aion.codecs import ImageCodec +from aion.codecs.config import HF_REPO_ID +from aion.modalities import Image, LegacySurveyCatalog, 
LegacySurveyImage + + +def test_load_invalid_modality(): +    """Test that loading a codec with an incompatible modality raises an error.""" +    with pytest.raises(TypeError): +        ImageCodec.from_pretrained(HF_REPO_ID, modality=LegacySurveyCatalog) + + +def test_load_image_codec(): +    """Test that compatible image modalities load the same underlying codec weights.""" +    codec_image = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image) +    codec_legacy_survey_image = ImageCodec.from_pretrained( +        HF_REPO_ID, modality=LegacySurveyImage +    ) +    for param_image, param_legacy_survey_image in zip( +        codec_image.parameters(), codec_legacy_survey_image.parameters() +    ): +        assert torch.equal(param_image, param_legacy_survey_image)