Merged
89 changes: 89 additions & 0 deletions CLAUDE.md
@@ -0,0 +1,89 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

AION (AstronomIcal Omnimodal Network) is a large omnimodal transformer model for astronomical surveys. It processes 39 distinct astronomical data modalities using a two-stage architecture:

1. **Modality-specific tokenizers** transform raw inputs (images, spectra, catalogs, scalars) into discrete tokens
2. **Unified encoder-decoder transformer** processes all token streams via multimodal masked modeling (4M)

The model comes in three variants: Base (300M), Large (800M), and XLarge (3B parameters).

## Development Commands

### Testing
```bash
pytest # Run all tests
pytest tests/codecs/ # Run codec tests only
pytest tests/test_data/ # Uses pre-computed test data for validation
```

### Linting and Code Quality
```bash
ruff check . # Check code style and lint
ruff check . --fix # Auto-fix linting issues
```

### Installation for Development
```bash
pip install -e .[torch,dev] # Install in editable mode with dev dependencies
```

### Documentation
```bash
cd docs && make html # Build Sphinx documentation
```

## Architecture Overview

### Core Components

- **`aion/model.py`**: Main AION wrapper class, inherits from the FM (4M) transformer
- **`aion/fourm/`**: 4M (Massively Multimodal Masked Modeling) transformer implementation
  - `fm.py`: Core transformer architecture with encoder-decoder blocks
  - `modality_info.py`: Configuration for all 39 supported modalities
  - `encoder_embeddings.py` / `decoder_embeddings.py`: Modality-specific embedding layers
- **`aion/codecs/`**: Modality tokenization system
  - `manager.py`: Dynamic codec loading and management
  - `base.py`: Abstract base codec class
  - Individual codec implementations for images, spectra, scalars, etc.
- **`aion/modalities.py`**: Type definitions for all astronomical data types

### Key Design Patterns

1. **Modality System**: Each astronomical data type (flux, spectrum, catalog) has:
   - A modality class in `modalities.py` defining data structure
   - A codec in `codecs/` for tokenization
   - Embedding layers in `fourm/` for the transformer

2. **Token Keys**: Each modality has a `token_key` (e.g., `tok_image`, `tok_spectrum_sdss`) that maps between modalities and model components

3. **HuggingFace Integration**: Models and codecs are distributed via HuggingFace Hub with `from_pretrained()` methods
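
As a rough illustration of the modality and token-key patterns, the sketch below defines two toy modality dataclasses and indexes them by `token_key`. Everything here is hypothetical: `ToyFluxG`, `ToyRedshift`, the `tok_toy_*` keys, `TOKEN_KEY_TO_MODALITY`, and `modality_for` are invented for the example, and the `BaseModality` shown is a stand-in, not the class in `aion/modalities.py`.

```python
from dataclasses import dataclass
from typing import ClassVar


@dataclass
class BaseModality:
    """Hypothetical stand-in for the base class in aion/modalities.py."""

    token_key: ClassVar[str] = ""


@dataclass
class ToyFluxG(BaseModality):
    """Illustrative scalar modality; not one of the real AION classes."""

    value: list[float]
    token_key: ClassVar[str] = "tok_toy_flux_g"


@dataclass
class ToyRedshift(BaseModality):
    """Second illustrative modality with its own token key."""

    value: list[float]
    token_key: ClassVar[str] = "tok_toy_z"


# Index the modality classes by token key so other components can look them up
TOKEN_KEY_TO_MODALITY = {m.token_key: m for m in (ToyFluxG, ToyRedshift)}


def modality_for(token_key: str) -> type[BaseModality]:
    """Return the modality class registered under token_key."""
    try:
        return TOKEN_KEY_TO_MODALITY[token_key]
    except KeyError:
        raise KeyError(f"Unknown token key: {token_key}") from None
```

Because `token_key` is declared as a `ClassVar`, it is shared by all instances and is not treated as a dataclass field, so it can be read from either the class or an instance.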

## Code Conventions

- Type hints are mandatory, using `jaxtyping` for tensor shapes (e.g., `Float[Tensor, "batch height width"]`)
- Modality classes use `@dataclass` and inherit from `BaseModality`
- All tensor operations should handle device placement explicitly
- Test data is pre-computed and stored in `tests/test_data/` as `.pt` files

## Testing Strategy

Tests validate both encoding and decoding for each modality using pre-computed reference data. The test pattern is:
1. Load input, encoded, and decoded reference tensors
2. Run codec encode/decode operations
3. Assert outputs match reference data within tolerance

Test files follow the naming pattern `test_{modality}_codec.py`.
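
The three-step pattern can be sketched as follows. This is a minimal, self-contained illustration and not the repository's test code: `ToyLinearCodec` and `check_codec` are hypothetical, plain Python lists stand in for tensors, and the reference values are written inline rather than loaded from `.pt` files.

```python
import math


class ToyLinearCodec:
    """Hypothetical codec: uniform quantization of values in [0, 1]."""

    def __init__(self, codebook_size: int):
        self.codebook_size = codebook_size

    def encode(self, values: list[float]) -> list[int]:
        # Map each value to the index of its bin, clamped to the valid range
        top = self.codebook_size - 1
        return [min(top, max(0, int(v * self.codebook_size))) for v in values]

    def decode(self, tokens: list[int]) -> list[float]:
        # Reconstruct each value as the centre of its bin
        return [(t + 0.5) / self.codebook_size for t in tokens]


def check_codec(codec, reference: dict, tol: float = 1e-6) -> None:
    """Steps 2 and 3: run encode/decode and compare against the reference."""
    encoded = codec.encode(reference["input"])
    assert encoded == reference["encoded"]
    decoded = codec.decode(encoded)
    assert len(decoded) == len(reference["decoded"])
    assert all(
        math.isclose(d, r, abs_tol=tol)
        for d, r in zip(decoded, reference["decoded"])
    )


# Step 1: the real tests would load these from pre-computed files instead
reference = {
    "input": [0.1, 0.5, 0.9],
    "encoded": [0, 2, 3],
    "decoded": [0.125, 0.625, 0.875],
}
check_codec(ToyLinearCodec(codebook_size=4), reference)
```

Comparing tokens exactly and decoded values within a tolerance reflects that token indices are discrete while reconstructions are floating point.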

## Astronomical Context

The model processes data from major surveys:
- **Legacy Survey**: Optical images and catalogs (g,r,i,z bands + WISE)
- **HSC (Hyper Suprime-Cam)**: Deep optical imaging (g,r,i,z,y bands)
- **Gaia**: Astrometry, photometry, and BP/RP spectra
- **SDSS/DESI**: Optical spectra

Each modality represents different physical measurements (flux, shape parameters, coordinates, extinction, etc.) that the model learns to correlate.
60 changes: 34 additions & 26 deletions aion/codecs/config.py
@@ -18,12 +18,14 @@
     HSCAY,
     HSCAZ,
     Dec,
+    DESISpectrum,
     GaiaFluxBp,
     GaiaFluxG,
     GaiaFluxRp,
     GaiaParallax,
     GaiaXpBp,
     GaiaXpRp,
+    HSCImage,
     HSCMagG,
     HSCMagI,
     HSCMagR,
@@ -43,11 +45,13 @@
     LegacySurveyFluxW3,
     LegacySurveyFluxW4,
     LegacySurveyFluxZ,
+    LegacySurveyImage,
     LegacySurveySegmentationMap,
     LegacySurveyShapeE1,
     LegacySurveyShapeE2,
     LegacySurveyShapeR,
     Ra,
+    SDSSSpectrum,
     Spectrum,
     Z,
 )
@@ -76,43 +80,47 @@ class CodecHFConfig:


 MODALITY_CODEC_MAPPING = {
+    Dec: ScalarCodec,
+    DESISpectrum: SpectrumCodec,
+    GaiaFluxBp: LogScalarCodec,
+    GaiaFluxG: LogScalarCodec,
+    GaiaFluxRp: LogScalarCodec,
+    GaiaParallax: LogScalarCodec,
+    GaiaXpBp: MultiScalarCodec,
+    GaiaXpRp: MultiScalarCodec,
+    HSCAG: ScalarCodec,
+    HSCAI: ScalarCodec,
+    HSCAR: ScalarCodec,
+    HSCAY: ScalarCodec,
+    HSCAZ: ScalarCodec,
+    HSCImage: ImageCodec,
+    HSCMagG: ScalarCodec,
+    HSCMagI: ScalarCodec,
+    HSCMagR: ScalarCodec,
+    HSCMagY: ScalarCodec,
+    HSCMagZ: ScalarCodec,
+    HSCShape11: ScalarCodec,
+    HSCShape12: ScalarCodec,
+    HSCShape22: ScalarCodec,
     Image: ImageCodec,
-    Spectrum: SpectrumCodec,
     LegacySurveyCatalog: CatalogCodec,
-    LegacySurveySegmentationMap: ScalarFieldCodec,
+    LegacySurveyEBV: ScalarCodec,
     LegacySurveyFluxG: LogScalarCodec,
-    LegacySurveyFluxR: LogScalarCodec,
     LegacySurveyFluxI: LogScalarCodec,
-    LegacySurveyFluxZ: LogScalarCodec,
+    LegacySurveyFluxR: LogScalarCodec,
     LegacySurveyFluxW1: LogScalarCodec,
     LegacySurveyFluxW2: LogScalarCodec,
     LegacySurveyFluxW3: LogScalarCodec,
     LegacySurveyFluxW4: LogScalarCodec,
-    LegacySurveyShapeR: LogScalarCodec,
-    GaiaFluxG: LogScalarCodec,
-    GaiaFluxBp: LogScalarCodec,
-    GaiaFluxRp: LogScalarCodec,
-    GaiaParallax: LogScalarCodec,
+    LegacySurveyFluxZ: LogScalarCodec,
+    LegacySurveyImage: ImageCodec,
+    LegacySurveySegmentationMap: ScalarFieldCodec,
     LegacySurveyShapeE1: ScalarCodec,
     LegacySurveyShapeE2: ScalarCodec,
-    LegacySurveyEBV: ScalarCodec,
-    HSCMagG: ScalarCodec,
-    HSCMagR: ScalarCodec,
-    HSCMagI: ScalarCodec,
-    HSCMagZ: ScalarCodec,
-    HSCMagY: ScalarCodec,
-    HSCShape11: ScalarCodec,
-    HSCShape22: ScalarCodec,
-    HSCShape12: ScalarCodec,
-    HSCAG: ScalarCodec,
-    HSCAR: ScalarCodec,
-    HSCAI: ScalarCodec,
-    HSCAZ: ScalarCodec,
-    HSCAY: ScalarCodec,
+    LegacySurveyShapeR: LogScalarCodec,
     Ra: ScalarCodec,
-    Dec: ScalarCodec,
-    GaiaXpBp: MultiScalarCodec,
-    GaiaXpRp: MultiScalarCodec,
+    SDSSSpectrum: SpectrumCodec,
+    Spectrum: SpectrumCodec,
     Z: GridScalarCodec,
 }

5 changes: 0 additions & 5 deletions aion/codecs/manager.py
@@ -60,11 +60,6 @@ def _load_codec(self, modality_type: type[Modality]) -> Codec:
         # Look up configuration in CODEC_CONFIG
         if modality_type in MODALITY_CODEC_MAPPING:
             codec_class = MODALITY_CODEC_MAPPING[modality_type]
-        elif (
-            hasattr(modality_type, "__base__")
-            and modality_type.__base__ in MODALITY_CODEC_MAPPING
-        ):
-            codec_class = MODALITY_CODEC_MAPPING[modality_type.__base__]
         else:
             raise ModalityTypeError(
                 f"No codec configuration found for modality type: {modality_type.__name__}"
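
With the `__base__` fallback removed, `_load_codec` resolves a codec only when the exact modality class is a key of `MODALITY_CODEC_MAPPING`, which is presumably why `config.py` now registers `HSCImage`, `LegacySurveyImage`, `DESISpectrum`, and `SDSSSpectrum` explicitly. The sketch below illustrates that behaviour with toy stand-ins. None of these definitions are the real AION classes, `load_codec_class` is a hypothetical helper, and the assumption that `HSCImage` subclasses `Image` is made only for illustration.

```python
class Image: ...
class HSCImage(Image): ...
class ImageCodec: ...


class ModalityTypeError(TypeError):
    """Stand-in for the error raised by the codec manager."""


def load_codec_class(modality_type: type, mapping: dict) -> type:
    """Exact-type lookup with no fallback to the parent class."""
    if modality_type in mapping:
        return mapping[modality_type]
    raise ModalityTypeError(
        f"No codec configuration found for modality type: {modality_type.__name__}"
    )


# Only the base class is registered: the subclass does not resolve
base_only = {Image: ImageCodec}
# Every concrete class is registered explicitly
explicit = {Image: ImageCodec, HSCImage: ImageCodec}

try:
    load_codec_class(HSCImage, base_only)
    resolved_via_base = True
except ModalityTypeError:
    resolved_via_base = False
```

Under these assumptions, a new modality subclass needs its own mapping entry; inheriting from a registered class is not enough.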
8 changes: 4 additions & 4 deletions aion/codecs/scalar.py
@@ -64,7 +64,7 @@ def __init__(
         reservoir_size: int,
     ):
         super().__init__()
-        self._modality_class = next(m for m in ScalarModalities if m.name == modality)
+        self._modality_class = ScalarModalities[modality]
         self._quantizer = ScalarReservoirQuantizer(
             codebook_size=codebook_size,
             reservoir_size=reservoir_size,
@@ -80,7 +80,7 @@ def __init__(
         min_log_value: float | None = -3,
     ):
         super().__init__()
-        self._modality_class = next(m for m in ScalarModalities if m.name == modality)
+        self._modality_class = ScalarModalities[modality]
         self._quantizer = ScalarLogReservoirQuantizer(
             codebook_size=codebook_size,
             reservoir_size=reservoir_size,
@@ -99,7 +99,7 @@ def __init__(
         num_quantizers: int,
     ):
         super().__init__()
-        self._modality_class = next(m for m in ScalarModalities if m.name == modality)
+        self._modality_class = ScalarModalities[modality]
         self._quantizer = MultiScalarCompressedReservoirQuantizer(
             compression_fns=compression_fns,
             decompression_fns=decompression_fns,
@@ -112,7 +112,7 @@ def __init__(
 class GridScalarCodec(BaseScalarIdentityCodec):
     def __init__(self, modality: str, codebook_size: int):
         super().__init__()
-        self._modality_class = next(m for m in ScalarModalities if m.name == modality)
+        self._modality_class = ScalarModalities[modality]
         self._quantizer = ScalarLinearQuantizer(
             codebook_size=codebook_size,
             range=(0.0, 1.0),
5 changes: 0 additions & 5 deletions aion/codecs/scalar_field.py
@@ -1,4 +1,3 @@
-from functools import reduce
 from typing import Callable, Optional, Type

 import torch
@@ -17,10 +16,6 @@
 from .quantizers import FiniteScalarQuantizer, Quantizer


-def _deep_get(dictionary, path, default=None):
-    return reduce(lambda d, key: d[key], path.split("."), dictionary)
-
-
 class AutoencoderScalarFieldCodec(Codec):
     """Abstract class for autoencoding scalar field codecs."""

29 changes: 27 additions & 2 deletions aion/codecs/utils.py
@@ -7,6 +7,7 @@
 from aion.codecs.base import Codec
 from aion.modalities import Modality

+
 ORIGINAL_CONFIG_NAME = hub_mixin.constants.CONFIG_NAME
 ORIGINAL_PYTORCH_WEIGHTS_NAME = hub_mixin.constants.PYTORCH_WEIGHTS_NAME
 ORIGINAL_SAFETENSORS_SINGLE_FILE = hub_mixin.constants.SAFETENSORS_SINGLE_FILE
@@ -79,6 +80,30 @@ class CodecPytorchHubMixin(hub_mixin.PyTorchModelHubMixin):
     Instead they lie in the transformer model repo as subfolders.
     """

+    @staticmethod
+    def _validate_codec_modality(codec: type[Codec], modality: type[Modality]):
+        """Validate that a codec class is compatible with a modality.
+
+        Args:
+            codec: The codec class to validate
+            modality: The modality type to validate against
+
+        Raises:
+            TypeError: If the codec is not a valid codec class or is incompatible with the modality
+            ValueError: If the modality has no corresponding codec configuration
+        """
+        # Import MODALITY_CODEC_MAPPING here to avoid circular import
+        from aion.codecs.config import MODALITY_CODEC_MAPPING
+
+        if not issubclass(codec, Codec):
+            raise TypeError("Only codecs can be loaded using this method.")
+        if modality not in MODALITY_CODEC_MAPPING:
+            raise ValueError(f"Modality {modality} has no corresponding codec.")
+        elif MODALITY_CODEC_MAPPING[modality] != codec:
+            raise TypeError(
+                f"Modality {modality} is associated with {MODALITY_CODEC_MAPPING[modality]} codec but {codec} requested."
+            )
+
     @classmethod
     def from_pretrained(
         cls,
@@ -104,8 +129,8 @@ def from_pretrained(
         Raises:
             ValueError: If the class is not a codec subclass or modality is invalid.
         """
-        if not issubclass(cls, Codec):
-            raise ValueError("Only codec classes can be loaded using this method.")
+        # Validate codec-modality compatibility
+        cls._validate_codec_modality(cls, modality)

         # Validate modality
         _validate_modality(modality)
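
The new check has three failure paths and one success path. The sketch below reproduces the same decision order with toy classes so the outcomes can be seen side by side. It is an illustration only: `MAPPING`, `Unregistered`, `validate_codec_modality`, and `outcome` are hypothetical, and the `Codec`, `ScalarCodec`, `LogScalarCodec`, and `Ra` classes here are empty stand-ins for the real ones.

```python
class Codec: ...
class ScalarCodec(Codec): ...
class LogScalarCodec(Codec): ...
class Ra: ...
class Unregistered: ...


# Toy mapping; the real one lives in aion/codecs/config.py
MAPPING = {Ra: ScalarCodec}


def validate_codec_modality(codec: type, modality: type, mapping: dict = MAPPING) -> None:
    """Apply the same three checks, in the same order, as the method above."""
    if not issubclass(codec, Codec):
        raise TypeError("Only codecs can be loaded using this method.")
    if modality not in mapping:
        raise ValueError(f"Modality {modality} has no corresponding codec.")
    if mapping[modality] != codec:
        raise TypeError(
            f"Modality {modality} is associated with {mapping[modality]} codec "
            f"but {codec} requested."
        )


def outcome(codec: type, modality: type) -> str:
    """Return 'ok' or the name of the exception type that was raised."""
    try:
        validate_codec_modality(codec, modality)
    except (TypeError, ValueError) as err:
        return type(err).__name__
    return "ok"
```

Note that the comparison is an equality test on the class, so a subclass of the registered codec would also be rejected; whether that is intended is not stated in the diff.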
75 changes: 39 additions & 36 deletions aion/modalities.py
@@ -408,42 +408,45 @@ class GaiaXpRp(Scalar):
     token_key: ClassVar[str] = "tok_xp_rp"


-ScalarModalities = [
-    LegacySurveyFluxG,
-    LegacySurveyFluxR,
-    LegacySurveyFluxI,
-    LegacySurveyFluxZ,
-    LegacySurveyFluxW1,
-    LegacySurveyFluxW2,
-    LegacySurveyFluxW3,
-    LegacySurveyFluxW4,
-    LegacySurveyShapeR,
-    LegacySurveyShapeE1,
-    LegacySurveyShapeE2,
-    LegacySurveyEBV,
-    Z,
-    HSCAG,
-    HSCAR,
-    HSCAI,
-    HSCAZ,
-    HSCAY,
-    HSCMagG,
-    HSCMagR,
-    HSCMagI,
-    HSCMagZ,
-    HSCMagY,
-    HSCShape11,
-    HSCShape22,
-    HSCShape12,
-    GaiaFluxG,
-    GaiaFluxBp,
-    GaiaFluxRp,
-    GaiaParallax,
-    Ra,
-    Dec,
-    GaiaXpBp,
-    GaiaXpRp,
-]
+ScalarModalities = {
+    modality.name: modality
+    for modality in [
+        LegacySurveyFluxG,
+        LegacySurveyFluxR,
+        LegacySurveyFluxI,
+        LegacySurveyFluxZ,
+        LegacySurveyFluxW1,
+        LegacySurveyFluxW2,
+        LegacySurveyFluxW3,
+        LegacySurveyFluxW4,
+        LegacySurveyShapeR,
+        LegacySurveyShapeE1,
+        LegacySurveyShapeE2,
+        LegacySurveyEBV,
+        Z,
+        HSCAG,
+        HSCAR,
+        HSCAI,
+        HSCAZ,
+        HSCAY,
+        HSCMagG,
+        HSCMagR,
+        HSCMagI,
+        HSCMagZ,
+        HSCMagY,
+        HSCShape11,
+        HSCShape22,
+        HSCShape12,
+        GaiaFluxG,
+        GaiaFluxBp,
+        GaiaFluxRp,
+        GaiaParallax,
+        Ra,
+        Dec,
+        GaiaXpBp,
+        GaiaXpRp,
+    ]
+}

 # Convenience type for any modality data
 ModalityType = (
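
Turning `ScalarModalities` into a name-keyed dict replaces the linear scan in `scalar.py` with a direct lookup, and it also changes what happens for an unknown name. The sketch below compares the two styles on toy classes; `ToyRa`, `ToyDec`, `find_by_scan`, and `find_by_key` are hypothetical names used only for this example.

```python
class ToyRa:
    name = "ra"


class ToyDec:
    name = "dec"


modalities_list = [ToyRa, ToyDec]
# Same construction as the new ScalarModalities: key each class by its name
modalities_by_name = {m.name: m for m in modalities_list}


def find_by_scan(name: str) -> type:
    """Old style: linear scan; raises StopIteration if nothing matches."""
    return next(m for m in modalities_list if m.name == name)


def find_by_key(name: str) -> type:
    """New style: dict lookup; raises KeyError if nothing matches."""
    return modalities_by_name[name]
```

One thing to keep in mind with the dict form: if two classes ever shared the same `name`, the later one would silently replace the earlier one in the comprehension.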