diff --git a/docs/api.md b/docs/api.md
index 8382838..f861f32 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -1,691 +1,363 @@
# API Reference
-This comprehensive API reference covers all major components of AION-1, including modalities, codecs, models, and utilities.
+This API reference covers the core components you'll actually use with AION-1, based on the working implementation.
## Core Model
### `aion.AION`
-The main AION model class that provides high-level interfaces for multimodal astronomical analysis.
+The main AION model class that provides multimodal astronomical analysis.
```python
-class AION(FourM):
+from aion import AION
+
+class AION(FM):
"""
AION-1 multimodal astronomical foundation model.
- Inherits from FourM architecture and adds astronomical-specific
- functionality for processing 39 different data modalities.
+ Inherits from FM (4M) architecture and adds astronomical-specific
+ functionality for processing multiple data modalities.
"""
@classmethod
- def from_pretrained(
- cls,
- model_name: str,
- device: str = 'cuda',
- torch_dtype: torch.dtype = torch.float32,
- **kwargs
- ) -> 'AION':
+ def from_pretrained(cls, model_name: str, **kwargs) -> 'AION':
"""
- Load a pre-trained AION model.
+ Load a pre-trained AION model from HuggingFace Hub.
Args:
model_name: HuggingFace model identifier
- - 'polymathic-ai/aion-tiny': 300M parameter model
- - 'polymathic-ai/aion-base': 800M parameter model
- - 'polymathic-ai/aion-large': 3.1B parameter model
- device: Device to load model on ('cuda', 'cpu', 'mps')
- torch_dtype: Data type for model weights
- **kwargs: Additional arguments passed to model constructor
+ - 'polymathic-ai/aion-base': 300M parameter model
Returns:
AION model instance
+
+ Example:
+ >>> model = AION.from_pretrained('polymathic-ai/aion-base')
+ >>> model = model.to('cuda').eval()
"""
- def generate(
+ def forward(
self,
- inputs: Dict[str, Modality],
- targets: List[str],
- num_generations: int = 1,
- temperature: float = 1.0,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None
- ) -> Dict[str, Modality]:
+ input_tokens: Dict[str, torch.Tensor],
+ target_mask: Optional[Dict[str, torch.Tensor]] = None,
+ num_encoder_tokens: int = 600,
+ **kwargs
+ ) -> Dict[str, torch.Tensor]:
"""
- Generate target modalities from input observations.
-
- Note:
- ``targets`` must be chosen from the list returned by
- ``AION.supported_targets`` (essentially the 39 modality names
- listed in the architecture documentation). Supplying an
- unsupported string will raise ``ValueError``.
+ Forward pass through the model.
Args:
- inputs: Dictionary mapping modality names to data
- targets: List of modality names to generate
- num_generations: Number of samples to generate
- temperature: Sampling temperature (higher = more diverse)
- top_k: Top-k sampling parameter
- top_p: Nucleus sampling parameter
+ input_tokens: Dictionary mapping modality token keys to token tensors
+ target_mask: Dictionary specifying which tokens to predict
+ Format: {"tok_z": torch.zeros(batch_size, num_target_tokens)}
+ num_encoder_tokens: Number of tokens to use in encoder
Returns:
- Dictionary mapping target names to generated modalities
+ Dictionary mapping target keys to prediction logits
+
+ Example:
+ >>> predictions = model(
+ ... tokens,
+ ... target_mask={"tok_z": torch.zeros(32, 1)},
+ ... num_encoder_tokens=600
+ ... )
+ >>> redshift_probs = torch.softmax(predictions["tok_z"], dim=-1)
"""
def encode(
self,
- inputs: Dict[str, torch.Tensor]
+ input_tokens: Dict[str, torch.Tensor],
+ num_encoder_tokens: int = 600
) -> torch.Tensor:
"""
- Encode input tokens to learned representations.
+ Extract embeddings from input tokens.
Args:
- inputs: Tokenized inputs
+ input_tokens: Dictionary of tokenized modality data
+ num_encoder_tokens: Number of tokens for encoder processing
Returns:
- Encoder hidden states [batch, seq_len, hidden_dim]
- """
+ Encoder embeddings with shape [batch, seq_len, hidden_dim]
- def tokenize(
- self,
- modalities: Dict[str, Modality]
- ) -> Dict[str, torch.Tensor]:
+ Example:
+ >>> embeddings = model.encode(tokens, num_encoder_tokens=600)
+ >>> # Use embeddings for downstream tasks
+ >>> pooled = embeddings.mean(dim=1) # [batch, hidden_dim]
"""
- Convert modalities to discrete tokens using codecs.
-
- Args:
- modalities: Dictionary of modality data
-
- Returns:
- Dictionary of tokenized tensors
- """
-```
-
-## Modalities
-
-AION-1 supports 39 different astronomical data modalities. Each modality is represented by a Pydantic model ensuring type safety and validation.
-
-### Image Modalities
-
-#### `aion.modalities.Image`
-
-```python
-class Image(Modality):
- """
- Multi-band astronomical image.
-
- Attributes:
- flux: Image data array [bands, height, width]
- bands: List of band identifiers (e.g., ['HSC-G', 'HSC-R'])
- ivar: Optional inverse variance array for weighting
- mask: Optional boolean mask array
- """
-
- flux: np.ndarray
- bands: List[str]
- ivar: Optional[np.ndarray] = None
- mask: Optional[np.ndarray] = None
-
- @classmethod
- def batch(cls, images: List['Image']) -> 'Image':
- """Batch multiple images together."""
-
- def crop(self, size: int = 96) -> 'Image':
- """Center crop image to specified size."""
-```
-
-### Spectrum Modalities
-
-#### `aion.modalities.Spectrum`
-
-```python
-class Spectrum(Modality):
- """
- Astronomical spectrum.
-
- Attributes:
- wavelength: Wavelength array in Angstroms
- flux: Flux density array
- ivar: Optional inverse variance
- survey: Source survey identifier
- """
-
- wavelength: np.ndarray
- flux: np.ndarray
- ivar: Optional[np.ndarray] = None
- survey: Optional[str] = None
-
- def resample(
- self,
- new_wavelength: np.ndarray
- ) -> 'Spectrum':
- """Resample spectrum to new wavelength grid."""
-
- def normalize(self) -> 'Spectrum':
- """Apply median normalization."""
-```
-
-### Scalar Modalities
-
-AION-1 includes numerous scalar modalities for photometry, shapes, and physical parameters:
-
-#### Photometric Fluxes
-
-```python
-class FluxG(ScalarModality):
- """g-band flux measurement."""
- value: np.ndarray
- error: Optional[np.ndarray] = None
-
-class FluxR(ScalarModality):
- """r-band flux measurement."""
- value: np.ndarray
- error: Optional[np.ndarray] = None
-
-class FluxI(ScalarModality):
- """i-band flux measurement."""
- value: np.ndarray
- error: Optional[np.ndarray] = None
-
-class FluxZ(ScalarModality):
- """z-band flux measurement."""
- value: np.ndarray
- error: Optional[np.ndarray] = None
```
-#### Shape Parameters
+## Codec Management
-```python
-class E1(ScalarModality):
- """First ellipticity component."""
- value: np.ndarray
-
-class E2(ScalarModality):
- """Second ellipticity component."""
- value: np.ndarray
-
-class RadiusCARP(ScalarModality):
- """CARP radius measurement."""
- value: np.ndarray
-```
+### `aion.codecs.CodecManager`
-#### Physical Properties
+Manages automatic loading and application of modality-specific codecs.
```python
-class Redshift(ScalarModality):
- """Spectroscopic or photometric redshift."""
- value: np.ndarray
- error: Optional[np.ndarray] = None
-
-class ExtinctionV(ScalarModality):
- """V-band extinction."""
- value: np.ndarray
-
-class Parallax(ScalarModality):
- """Parallax measurement in mas."""
- value: np.ndarray
- error: Optional[np.ndarray] = None
-```
-
-### Catalog Modalities
+from aion.codecs import CodecManager
-#### `aion.modalities.Catalog`
-
-```python
-class Catalog(Modality):
+class CodecManager:
"""
- Astronomical object catalog.
-
- Attributes:
- entries: List of catalog objects
- max_objects: Maximum number of objects to process
+ Central manager for encoding/decoding between modalities and tokens.
"""
- entries: List[CatalogEntry]
- max_objects: int = 100
-
- def sort_by_distance(self) -> 'Catalog':
- """Sort entries by distance from center."""
-
- def filter_bright(self, magnitude_limit: float) -> 'Catalog':
- """Filter to objects brighter than limit."""
-```
-
-## Codecs (Tokenizers)
-
-Codecs convert between modalities and discrete tokens. Each modality type has a specialized codec.
-
-### Base Codec Interface
-
-#### `aion.codecs.base.Codec`
-
-```python
-class Codec(ABC):
- """
- Abstract base class for modality codecs.
- """
-
- @abstractmethod
- def encode(self, modality: Modality) -> torch.Tensor:
- """Encode modality to discrete tokens."""
-
- @abstractmethod
- def decode(self, tokens: torch.Tensor) -> Modality:
- """Decode tokens back to modality."""
-
- @classmethod
- def from_pretrained(cls, path: str) -> 'Codec':
- """Load pre-trained codec."""
-
- def save_pretrained(self, path: str):
- """Save codec weights and configuration."""
-```
-
-### Image Codec
-
-#### `aion.codecs.ImageCodec`
-
-```python
-class ImageCodec(Codec):
- """
- Image tokenizer using MagVit architecture.
-
- Supports multi-survey images with different band counts
- through a unified channel embedding scheme.
- """
-
- def __init__(
- self,
- hidden_dim: int = 512,
- n_embed: int = 10000,
- compression_levels: int = 2,
- quantizer: str = 'fsq'
- ):
+ def __init__(self, device: str = 'cuda'):
"""
- Initialize image codec.
+ Initialize codec manager.
Args:
- hidden_dim: Hidden dimension size
- n_embed: Codebook size
- compression_levels: Spatial compression factor
- quantizer: Quantization method ('fsq' or 'vq')
- """
-
- def preprocess(
- self,
- image: Image,
- crop_size: int = 96
- ) -> torch.Tensor:
- """Apply survey-specific preprocessing."""
-
- def get_latent_shape(
- self,
- image_shape: Tuple[int, ...]
- ) -> Tuple[int, ...]:
- """Get shape of latent representation."""
-```
+ device: Device to load codecs on ('cuda', 'cpu')
-### Spectrum Codec
-
-#### `aion.codecs.SpectrumCodec`
-
-```python
-class SpectrumCodec(Codec):
- """
- Spectrum tokenizer using ConvNeXt V2 architecture.
-
- Uses a shared latent wavelength grid to handle spectra
- from different instruments.
- """
-
- def __init__(
- self,
- latent_wavelength: np.ndarray,
- hidden_dims: List[int] = [96, 192, 384, 768],
- n_embed: int = 1024,
- quantizer: str = 'lfq'
- ):
+ Example:
+ >>> codec_manager = CodecManager(device='cuda')
"""
- Initialize spectrum codec.
- Args:
- latent_wavelength: Target wavelength grid
- hidden_dims: ConvNeXt stage dimensions
- n_embed: Codebook size
- quantizer: Quantization method
+ def encode(self, *modalities) -> Dict[str, torch.Tensor]:
"""
+ Encode modalities into discrete tokens.
- def to_latent_grid(
- self,
- spectrum: Spectrum
- ) -> torch.Tensor:
- """Interpolate spectrum to latent wavelength grid."""
-```
-
-### Scalar Codec
+ Args:
+ *modalities: Variable number of modality objects
-#### `aion.codecs.ScalarCodec`
+ Returns:
+ Dictionary mapping token keys to token tensors
-```python
-class ScalarCodec(Codec):
- """
- Tokenizer for scalar quantities using adaptive quantization.
- """
+ Example:
+ >>> tokens = codec_manager.encode(image, spectrum, flux_g)
+ >>> # Returns: {"tok_image": tensor(...), "tok_spectrum_sdss": tensor(...), "tok_flux_g": tensor(...)}
+ """
- def __init__(
+ def decode(
self,
- quantizer_type: str = 'reservoir',
- n_bins: int = 256
+ tokens: Dict[str, torch.Tensor],
+ modality_class: type,
+ **metadata
):
"""
- Initialize scalar codec.
+ Decode tokens back to modality objects.
Args:
- quantizer_type: Type of quantizer
- - 'linear': Uniform bins
- - 'log': Logarithmic bins
- - 'reservoir': Learned adaptive bins
- - 'compressed': Transform then quantize
- n_bins: Number of quantization levels
- """
+ tokens: Dictionary of token tensors
+ modality_class: Class of modality to decode (e.g., LegacySurveyImage)
+ **metadata: Additional metadata required for reconstruction
- def fit(self, values: np.ndarray):
- """Fit quantizer to data distribution."""
+ Returns:
+ Reconstructed modality object
+
+ Example:
+ >>> reconstructed = codec_manager.decode(
+ ... tokens,
+ ... LegacySurveyImage,
+ ... bands=["DES-G", "DES-R", "DES-I", "DES-Z"]
+ ... )
+ """
```
-## Quantizers
+## Modalities
-Quantization modules that convert continuous values to discrete tokens.
+AION-1 uses a typed modality system to ensure data compatibility and provenance tracking.
-### `aion.codecs.quantizers.FSQ`
+### Base Classes
```python
-class FiniteScalarQuantization(nn.Module):
- """
- Finite Scalar Quantization from MagVit.
+from aion.modalities import BaseModality
- Factorizes codebook into multiple small codebooks for
- better gradient flow and training stability.
- """
+class BaseModality:
+ """Base class for all astronomical modalities."""
- def __init__(
- self,
- levels: List[int] = [8, 5, 5, 5, 5],
- eps: float = 1e-3
- ):
- """
- Args:
- levels: Number of levels per dimension
- eps: Small constant for numerical stability
- """
+ @property
+ def token_key(self) -> str:
+ """Unique identifier for this modality type in the model."""
```
-### `aion.codecs.quantizers.LFQ`
+### Image Modalities
```python
-class LookupFreeQuantization(nn.Module):
- """
- Lookup-Free Quantization using entropy regularization.
+from aion.modalities import LegacySurveyImage, HSCImage
- Achieves quantization without explicit codebook lookup,
- improving training efficiency.
+class LegacySurveyImage(BaseModality):
"""
+ Legacy Survey multi-band image.
- def __init__(
- self,
- dim: int,
- codebook_size: int,
- entropy_weight: float = 0.1
- ):
- """
- Args:
- dim: Embedding dimension
- codebook_size: Target vocabulary size
- entropy_weight: Entropy regularization weight
- """
-```
-
-## Preprocessing
-
-Survey-specific preprocessing utilities.
-
-### `aion.codecs.preprocessing.ImagePreprocessor`
-
-```python
-class ImagePreprocessor:
- """
- Survey-specific image preprocessing.
+ Attributes:
+ flux: Image tensor with shape [batch, 4, height, width] for g,r,i,z bands
+ bands: List of band identifiers (e.g., ['DES-G', 'DES-R', 'DES-I', 'DES-Z'])
"""
- def __init__(self, survey: str):
- """
- Initialize for specific survey.
-
- Args:
- survey: Survey name ('HSC', 'DES', 'SDSS', etc.)
- """
-
- def __call__(self, image: Image) -> torch.Tensor:
- """Apply preprocessing pipeline."""
-
- def get_rescaling_params(self) -> Dict[str, float]:
- """Get survey-specific rescaling parameters."""
-```
+ flux: torch.Tensor
+ bands: List[str]
-### `aion.codecs.preprocessing.SpectrumPreprocessor`
+ @property
+ def token_key(self) -> str:
+ return "tok_image"
-```python
-class SpectrumPreprocessor:
- """
- Spectrum normalization and preprocessing.
+class HSCImage(BaseModality):
"""
+ HSC multi-band image.
- def normalize_median(
- self,
- spectrum: Spectrum
- ) -> Spectrum:
- """Apply median normalization."""
-
- def mask_skylines(
- self,
- spectrum: Spectrum
- ) -> Spectrum:
- """Mask common sky emission lines."""
-```
-
-## Model Components
-
-### `aion.fourm.FourM`
-
-```python
-class FourM(nn.Module):
+ Attributes:
+ flux: Image tensor with shape [batch, 5, height, width] for g,r,i,z,y bands
+ bands: List of band identifiers
"""
- Base multimodal transformer architecture.
- Implements the encoder-decoder architecture with
- modality-specific embeddings and flexible attention.
- """
+ flux: torch.Tensor
+ bands: List[str]
- def __init__(
- self,
- encoder_depth: int = 12,
- decoder_depth: int = 12,
- dim: int = 768,
- num_heads: int = 12,
- mlp_ratio: float = 4.0,
- use_bias: bool = False
- ):
- """Initialize FourM architecture."""
+ @property
+ def token_key(self) -> str:
+ return "tok_image"
```
-### `aion.fourm.encoder_embeddings`
+### Spectrum Modalities
```python
-class ModalityEmbedding(nn.Module):
- """
- Learnable embeddings for each modality type.
+from aion.modalities import DESISpectrum, SDSSSpectrum
- Provides both modality identification and survey
- provenance information.
+class DESISpectrum(BaseModality):
"""
+ DESI spectroscopic observation.
- def __init__(
- self,
- num_modalities: int,
- num_surveys: int,
- embed_dim: int
- ):
- """Initialize modality embeddings."""
-```
-
-## Utilities
+ Attributes:
+ flux: Flux density array
+ ivar: Inverse variance array
+ mask: Boolean mask array
+ wavelength: Wavelength array in Angstroms
+ """
-### `aion.model_utils`
+ flux: torch.Tensor
+ ivar: torch.Tensor
+ mask: torch.Tensor
+ wavelength: torch.Tensor
-```python
-def load_codec(modality: str, device: str = 'cuda') -> Codec:
- """Load pre-trained codec for modality."""
+ @property
+ def token_key(self) -> str:
+ return "tok_spectrum_desi"
-def create_model_config(
- model_size: str = 'base'
-) -> Dict[str, Any]:
- """Get configuration for model size."""
+class SDSSSpectrum(BaseModality):
+ """SDSS spectroscopic observation."""
-def count_parameters(model: nn.Module) -> int:
- """Count trainable parameters in model."""
+ @property
+ def token_key(self) -> str:
+ return "tok_spectrum_sdss"
```
-### `aion.generation_utils`
+### Scalar Modalities
```python
-def sample_with_temperature(
- logits: torch.Tensor,
- temperature: float = 1.0,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None
-) -> torch.Tensor:
- """
- Sample from logits with temperature scaling.
-
- Args:
- logits: Model output logits
- temperature: Sampling temperature
- top_k: Top-k filtering
- top_p: Nucleus sampling threshold
-
- Returns:
- Sampled token indices
- """
+from aion.modalities import (
+ LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ,
+ Z, GaiaParallax
+)
-def generate_with_caching(
- model: AION,
- inputs: Dict[str, torch.Tensor],
- max_length: int,
- use_cache: bool = True
-) -> torch.Tensor:
- """Generate tokens with KV caching for efficiency."""
-```
+class LegacySurveyFluxG(BaseModality):
+ """Legacy Survey g-band flux measurement."""
-## Data Loading
+ value: torch.Tensor
-### `aion.data.AstronomicalDataset`
+ @property
+ def token_key(self) -> str:
+ return "tok_flux_g"
-```python
-class AstronomicalDataset(Dataset):
- """
- PyTorch dataset for astronomical observations.
- """
+class Z(BaseModality):
+ """Spectroscopic redshift."""
- def __init__(
- self,
- data_paths: List[str],
- modalities: List[str],
- transform: Optional[Callable] = None
- ):
- """
- Initialize dataset.
-
- Args:
- data_paths: Paths to data files
- modalities: List of modalities to load
- transform: Optional data transformation
- """
+ value: torch.Tensor
- def __getitem__(self, idx: int) -> Dict[str, Modality]:
- """Get single observation."""
+ @property
+ def token_key(self) -> str:
+ return "tok_z"
```
-## Example Usage
+## Complete Usage Example
-### Complete Pipeline
+Here's a comprehensive example showing the full workflow:
```python
import torch
from aion import AION
-from aion.modalities import Image, Spectrum
-from aion.codecs import ImageCodec, SpectrumCodec
+from aion.codecs import CodecManager
+from aion.modalities import (
+ LegacySurveyImage, DESISpectrum,
+ LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ
+)
+
+# 1. Load model and codec manager
+model = AION.from_pretrained('polymathic-ai/aion-base').to('cuda').eval()
+codec_manager = CodecManager(device='cuda')
-# Load model and codecs
-model = AION.from_pretrained('polymathic-ai/aion-base')
-image_codec = ImageCodec.from_pretrained('polymathic-ai/aion-image-codec')
-spectrum_codec = SpectrumCodec.from_pretrained('polymathic-ai/aion-spectrum-codec')
+# 2. Prepare data
+image = LegacySurveyImage(
+ flux=torch.tensor(image_data, dtype=torch.float32),
+ bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z']
+)
+
+spectrum = DESISpectrum(
+ flux=torch.tensor(flux_data),
+ ivar=torch.tensor(ivar_data),
+ mask=torch.tensor(mask_data, dtype=torch.bool),
+ wavelength=torch.tensor(wavelength_data)
+)
-# Load data
-image = Image(flux=galaxy_flux, bands=['g', 'r', 'i', 'z', 'y'])
-spectrum = Spectrum(wavelength=wavelength, flux=flux)
+flux_g = LegacySurveyFluxG(value=torch.tensor([flux_g_value]))
-# Tokenize
-tokens = {
- 'image': image_codec.encode(image),
- 'spectrum': spectrum_codec.encode(spectrum)
-}
+# 3. Encode to tokens
+tokens = codec_manager.encode(image, spectrum, flux_g)
-# Encode to representations
+# 4. Extract embeddings for downstream tasks
with torch.no_grad():
- representations = model.encode(tokens)
+ embeddings = model.encode(tokens, num_encoder_tokens=600)
+ pooled_embeddings = embeddings.mean(dim=1) # [batch, hidden_dim]
-# Generate missing modalities
-results = model.generate(
- inputs={'image': image},
- targets=['spectrum', 'redshift']
+# 5. Predict redshift
+with torch.no_grad():
+ predictions = model(
+ tokens,
+ target_mask={"tok_z": torch.zeros(1, 1)},
+ num_encoder_tokens=600
+ )
+ redshift_probs = torch.softmax(predictions["tok_z"][0], dim=-1)
+
+# 6. Decode tokens back to modalities
+reconstructed_image = codec_manager.decode(
+ tokens,
+ LegacySurveyImage,
+ bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z']
)
-
-# Decode results
-generated_spectrum = spectrum_codec.decode(results['spectrum'])
-print(f"Predicted redshift: {results['redshift'].value[0]:.3f}")
```
-## Error Handling
+## Model Variants
-All AION components include comprehensive error handling:
+Currently available pre-trained models:
-```python
-from aion.exceptions import (
- ModalityError, # Invalid modality data
- CodecError, # Tokenization failures
- ModelError, # Model inference errors
- DataError # Data loading issues
-)
+| Model | Parameters | HuggingFace ID |
+|-------|------------|----------------|
+| AION-Base | 300M | `polymathic-ai/aion-base` |
-try:
- result = model.generate(inputs, targets)
-except ModalityError as e:
- print(f"Invalid modality: {e}")
-except CodecError as e:
- print(f"Tokenization failed: {e}")
-```
+More model variants will be added as they become available.
+
+## Common Patterns
-## Performance Tips
+### Similarity Search
+```python
+def compute_similarities(query_tokens, database_tokens, model):
+ """Compute embedding similarities between query and database."""
+ with torch.no_grad():
+ query_emb = model.encode(query_tokens).mean(dim=1)
+ db_embs = model.encode(database_tokens).mean(dim=1)
+
+ from sklearn.metrics.pairwise import cosine_similarity
+ return cosine_similarity(query_emb.cpu(), db_embs.cpu())
+```
-1. **Batch Processing**: Always process multiple objects together when possible
-2. **Mixed Precision**: Use `torch.cuda.amp` for faster inference
-3. **Token Caching**: Reuse encoder outputs when generating multiple targets
-4. **Device Placement**: Use `.to(device)` consistently for all tensors
+### Batch Processing
+```python
+def process_batch(batch_data, model, codec_manager):
+ """Process a batch of astronomical objects."""
+ batch_tokens = codec_manager.encode(*batch_data)
-For more details, see the [Usage Guide](usage.md) and [Architecture](architecture.md) documentation.
+ with torch.no_grad():
+ embeddings = model.encode(batch_tokens, num_encoder_tokens=600)
-```{eval-rst}
-.. automodule:: aion
- :members:
- :undoc-members:
- :show-inheritance:
+ return embeddings.mean(dim=1) # Pooled embeddings
```
+
+For more examples, see the [Usage Guide](usage.md) and [Tutorial Notebook](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb).
diff --git a/docs/architecture.md b/docs/architecture.md
index 607c7f8..f4062cd 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -222,9 +222,9 @@ AION-1 comes in three sizes, each using the same architecture with different dim
| Model | Parameters | Encoder Layers | Decoder Layers | Hidden Dim | Attention Heads |
|-------|------------|----------------|----------------|------------|-----------------|
-| AION-1-B (Base) | 300M | 12 | 12 | 768 | 12 |
-| AION-1-L (Large) | 800M | 24 | 24 | 1024 | 16 |
-| AION-1-XL (XLarge) | 3.1B | 24 | 24 | 2048 | 32 |
+| AION-Base | ~300M | 12 | 12 | 768 | 12 |
+
+> **Note**: Additional model sizes may be released in the future. Current model ID: `polymathic-ai/aion-base`
All models use:
- SwiGLU activation functions
diff --git a/docs/index.md b/docs/index.md
index 0d44812..d414bae 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -3,65 +3,27 @@
AION-1
AstronomIcal Omnimodal Network
- The first large-scale multimodal foundation model for astronomy
+ Large-Scale Multimodal Foundation Model for Astronomy
```
-# Welcome to AION-1
-
-AION-1 (AstronomIcal Omnimodal Network) represents a breakthrough in astronomical machine learning: the first foundation model capable of understanding and processing arbitrary combinations of astronomical observations across 39 different data modalities. Trained on over 200 million astronomical objects, AION-1 unifies imaging, spectroscopy, photometry, and catalog data from major ground- and space-based observatories into a single, powerful framework.
+# AION-1 Documentation
## ๐ Why AION-1?
-Traditional approaches in astronomy treat each data modality in isolation, missing the rich interconnections between different types of observations. AION-1 fundamentally changes this paradigm by:
+Trained on over 200 million astronomical objects, AION-1 (AstronomIcal Omnimodal Network) is the first Foundation Model capable of unifying multiband imaging, spectroscopy, and photometry from major ground- and space-based observatories into a single framework.
-- **Learning Cross-Modal Relationships**: The model discovers how different observations relate to each other, building a deep understanding of the underlying astrophysical objects
+Compared to traditional machine learning approaches in Astronomy, AION-1 stands out on several points:
- **Enabling Flexible Data Fusion**: Scientists can use any combination of available observations without redesigning their analysis pipeline
+- **Enabling Easy Adaptation to Downstream Tasks**: Scientists can adapt AION-1 to new tasks in a matter of minutes and reach SOTA performance
- **Excelling in Low-Data Regimes**: AION-1 achieves competitive results with orders of magnitude less labeled data than supervised approaches
- **Providing Universal Representations**: The learned embeddings capture physically meaningful structure useful across diverse downstream tasks
-## ๐ Key Capabilities
-
-```{eval-rst}
-.. grid:: 1 1 2 3
- :gutter: 3
-
- .. grid-item-card:: ๐ 39 Data Modalities
- :class-card: feature-card
-
- Seamlessly integrates multiband images, optical spectra, photometry, and catalog data from HSC, Legacy Survey, SDSS, DESI, and Gaia
-
- .. grid-item-card:: ๐ง 200M+ Objects
- :class-card: feature-card
-
- Pre-trained on massive astronomical datasets spanning galaxies, stars, and quasars across multiple surveys
-
- .. grid-item-card:: ๐ง Flexible Architecture
- :class-card: feature-card
-
- Two-stage design with modality-specific tokenization followed by transformer-based multimodal masked modeling
-
- .. grid-item-card:: โก Emergent Behaviors
- :class-card: feature-card
-
- Demonstrates physical understanding, superior low-data performance, and meaningful latent space organization
-
- .. grid-item-card:: ๐ฏ Versatile Applications
- :class-card: feature-card
-
- Supports regression, classification, generation, retrieval, and cross-modal prediction tasks out-of-the-box
-
- .. grid-item-card:: ๐ Open Science
- :class-card: feature-card
-
- Fully open-source including datasets, training scripts, and model weights for reproducible research
-```
-
## ๐ Quick Start
Getting started with AION-1 is straightforward:
@@ -69,44 +31,42 @@ Getting started with AION-1 is straightforward:
```python
# Minimal end-to-end example
from aion import AION
-import numpy as np
+from aion.codecs import CodecManager
+from aion.modalities import (LegacySurveyImage, LegacySurveyFluxG,
+LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ)
+
+# 1) Load a pre-trained checkpoint (300 M parameters)
+model = AION.from_pretrained('polymathic-ai/aion-base').to('cuda').eval()
+codec_manager = CodecManager(device='cuda') # Manages codecs for each modality
+
+# 2) Prepare demo inputs (96ร96 g,r,i,z cut-out and photometry)
+# Create image modality
+image = LegacySurveyImage(
+ flux=data["legacysurvey_image_flux"],
+ bands=["DES-G", "DES-R", "DES-I", "DES-Z"],
+)
-# 1) Load a pre-trained checkpoint (800 M parameters)
-model = AION.from_pretrained('polymathic-ai/aion-base')
+# Create flux modalities
+g = LegacySurveyFluxG(value=data["legacysurvey_FLUX_G"])
+r = LegacySurveyFluxR(value=data["legacysurvey_FLUX_R"])
+i = LegacySurveyFluxI(value=data["legacysurvey_FLUX_I"])
+z = LegacySurveyFluxZ(value=data["legacysurvey_FLUX_Z"])
-# 2) Prepare demo inputs (96ร96 HSC g,r,i,z,y cut-out and SDSS spectrum)
-galaxy_image = np.load('hsc_cutout_5band.npy') # shape (5,96,96)
-galaxy_spectrum = np.load('sdss_spectrum.npy') # dict with wavelength/flux
+# Encode input modalities into tokens
+tokens = codec_manager.encode(image, g, r, i, z)
-# 3) Generate a high-resolution DESI-like spectrum from the image
-generated = model.generate(
- inputs={'image': galaxy_image},
- targets=['spectrum']
+# 3) Generate a redshift distribution from these set of inputs
+predictions = model(
+ tokens,
+ target_mask={"tok_z": torch.zeros(batch_size, 1)},
+ num_encoder_tokens=600
)
+redshift_logits = predictions["tok_z"] # Shape: [batch, sequence, vocab_size]
# 4) Extract joint embeddings for downstream use
-embeddings = model.encode({'image': galaxy_image, 'spectrum': galaxy_spectrum})
+embeddings = model.encode(tokens, num_encoder_tokens=600) # Shape: [batch, seq_len, hidden_dim]
```
-## ๐ฌ Scientific Impact
-
-AION-1 demonstrates several emergent behaviors that reflect its deep understanding of astronomical data:
-
-### Physical Understanding
-- Solves non-trivial scientific tasks using only simple linear probes on learned representations
-- Organizes objects in embedding space along physically meaningful dimensions
-- Captures relationships between disparate observations of the same physical phenomena
-
-### Performance Advantages
-- Achieves state-of-the-art results on galaxy property estimation, stellar parameter prediction, and morphology classification
-- Outperforms supervised baselines by 3x on rare object detection tasks
-- Enables accurate cross-modal prediction even for modality pairs never seen during training
-
-### Practical Benefits
-- Reduces data requirements by orders of magnitude for downstream tasks
-- Enables seamless integration of heterogeneous observations
-- Provides robust uncertainty quantification through multiple sampling
-
## ๐ Documentation Overview
```{eval-rst}
@@ -119,11 +79,11 @@ AION-1 demonstrates several emergent behaviors that reflect its deep understandi
Environment setup, dependencies, and configuration
- .. grid-item-card:: Model Architecture
+ .. grid-item-card:: Model Specifications
:link: architecture.html
:class-card: doc-card
- Deep dive into tokenization, transformers, and design
+ Deep dive into tokenization, transformers, and trarining data
.. grid-item-card:: Usage Guide
:link: usage.html
@@ -147,13 +107,3 @@ architecture
usage
api
```
-
-## ๐ค Join the Community
-
-```{raw} html
-
-```
diff --git a/docs/installation.md b/docs/installation.md
index f28e27c..6225e4d 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -1,89 +1,111 @@
# Installation Guide
-This comprehensive guide will walk you through installing AION-1 and setting up your environment for astronomical multimodal analysis.
+Quick and straightforward installation guide for AION-1.
## System Requirements
### Hardware Requirements
-AION-1 is designed to run efficiently on various hardware configurations:
+**Minimum (CPU only)**:
+- 16 GB RAM
+- 20 GB free storage
-- **Minimum Requirements**:
- - CPU: 4+ cores (Intel/AMD x86_64 or Apple Silicon)
- - RAM: 16 GB
- - GPU: NVIDIA GPU with 8GB+ VRAM (optional but recommended)
- - Storage: 50 GB free space for models and data
+**Recommended (GPU)**:
+- NVIDIA GPU with 8GB+ VRAM
+- 32 GB RAM
+- 50 GB free storage
-- **Recommended Requirements**:
- - CPU: 8+ cores
- - RAM: 32 GB or more
- - GPU: NVIDIA GPU with 24GB+ VRAM (e.g., RTX 3090, A5000, or better)
- - Storage: 100 GB+ free space
-
-- **For Large-Scale Processing**:
- - Multiple GPUs with NVLink
- - 64GB+ RAM
- - Fast SSD storage for data loading
+**For Large-Scale Processing**:
+- NVIDIA GPU with 24GB+ VRAM (e.g., RTX 4090, A5000+)
+- 64GB+ RAM
### Software Requirements
-- Python 3.10 or later
-- CUDA 11.8+ (for GPU support)
-- Operating System: Linux, macOS, or Windows
-
-## Installation Methods
+- Python 3.10+
+- CUDA 11.8+ (for GPU acceleration)
+- Linux, macOS, or Windows
-### 1. Quick Install via PyPI
+## Installation
-The simplest way to install AION-1 is through PyPI:
+### Quick Install (Recommended)
```bash
+# Install PyTorch with CUDA support (adjust CUDA version as needed)
+pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
+
+# Install AION
pip install aion
```
-This installs the core AION package with minimal dependencies.
-
-### 2. Full Installation with PyTorch
-
-For GPU support and optimal performance:
+### Alternative: CPU-only Installation
```bash
-# Install PyTorch first (adjust for your CUDA version)
-pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
+# Install CPU-only PyTorch
+pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
-# Then install AION
+# Install AION
pip install aion
```
-### 3. Development Installation
-
-For contributors or those who want the latest features:
+### Development Installation
```bash
-# Clone the repository
git clone https://github.com/polymathic-ai/aion.git
cd aion
+pip install -e ".[torch,dev]"
+```
+
+## Verification
+
+Test your installation:
-# Create a virtual environment
-python -m venv venv
-source venv/bin/activate # On Windows: venv\Scripts\activate
+```python
+import torch
+from aion import AION
+from aion.codecs import CodecManager
-# Install in development mode
-pip install -e ".[dev]"
+print(f"PyTorch version: {torch.__version__}")
+print(f"CUDA available: {torch.cuda.is_available()}")
+
+# Test model loading (requires internet connection)
+try:
+ model = AION.from_pretrained('polymathic-ai/aion-base')
+ print("โ AION model loaded successfully")
+except Exception as e:
+ print(f"โ Model loading failed: {e}")
+
+# Test codec manager
+try:
+ codec_manager = CodecManager(device='cuda' if torch.cuda.is_available() else 'cpu')
+ print("โ CodecManager initialized successfully")
+except Exception as e:
+ print(f"โ CodecManager failed: {e}")
```
-## Setting Up Your Environment
+## Troubleshooting
-### 1. Virtual Environment Setup
+### Common Issues
-We strongly recommend using a virtual environment:
+**CUDA out of memory**:
+```bash
+# Use smaller model or CPU
+model = AION.from_pretrained('polymathic-ai/aion-base').to('cpu')
+```
+**HuggingFace connection issues**:
```bash
-# Using venv
-python -m venv aion-env
-source aion-env/bin/activate # On Windows: aion-env\Scripts\activate
+# Set up HuggingFace cache directory
+export HF_HOME=/path/to/cache
+```
-# Using conda
-conda create -n aion python=3.10
-conda activate aion
+**Import errors**:
+```bash
+# Reinstall with fresh environment
+pip uninstall aion torch
+pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
+pip install aion
```
+
+## Next Steps
+
+Once installed, try the [Tutorial Notebook](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb) or check the [Usage Guide](usage.md) for examples.
diff --git a/docs/usage.md b/docs/usage.md
index 5d03545..855dbea 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -1,723 +1,609 @@
# AION-1 Usage Guide
-This comprehensive guide demonstrates how to use AION-1 for various astronomical analysis tasks. From basic inference to advanced multimodal generation, you'll learn to leverage AION-1's capabilities for your research.
+This comprehensive guide demonstrates how to use AION-1 for various astronomical analysis tasks, based on the actual working implementation.
## Table of Contents
1. [Quick Start](#quick-start)
-2. [Loading and Preprocessing Data](#loading-and-preprocessing-data)
-3. [Basic Inference](#basic-inference)
-4. [Multimodal Generation](#multimodal-generation)
-5. [Cross-Modal Translation](#cross-modal-translation)
-6. [Representation Learning](#representation-learning)
-7. [Advanced Applications](#advanced-applications)
-8. [Performance Optimization](#performance-optimization)
+2. [Loading and Preparing Data](#loading-and-preparing-data)
+3. [Basic Workflows](#basic-workflows)
+4. [Embedding Extraction](#embedding-extraction)
+5. [Similarity Search](#similarity-search)
+6. [Property Prediction](#property-prediction)
+7. [Performance Tips](#performance-tips)
## Quick Start
-Let's begin with a simple example that showcases AION-1's core capabilities:
+Here's how to get started with AION-1 in just a few lines:
```python
-import torch, numpy as np
+import torch
+import numpy as np
from aion import AION
-from aion.modalities import Image
+from aion.codecs import CodecManager
+from aion.modalities import LegacySurveyImage
+
+# 1. Load model and codec manager
+model = AION.from_pretrained('polymathic-ai/aion-base').to('cuda').eval()
+codec_manager = CodecManager(device='cuda')
-# 1) Load a checkpoint (300 M parameters)
-model = AION.from_pretrained('polymathic-ai/aion-tiny').eval()
+# 2. Prepare your astronomical data
+image = LegacySurveyImage(
+ flux=torch.tensor(your_image_data, dtype=torch.float32), # Shape: [batch, 4, height, width]
+ bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z']
+)
-# 2) Read an example 5-band HSC cut-out (units: nanomaggies)
-flux_cube = np.load('hsc_cutout_5band.npy') # shape (5,96,96)
-img = Image(flux=flux_cube, bands=['HSC-G','HSC-R','HSC-I','HSC-Z','HSC-Y'])
+# 3. Encode to tokens
+tokens = codec_manager.encode(image)
-# 3) Predict an SDSS-like spectrum (observer-frame, erg sโปยน cmโปยฒ ร
โปยน)
-with torch.inference_mode():
- result = model.generate(inputs={'image': img}, targets=['spectrum'])
+# 4. Extract embeddings for downstream analysis
+with torch.no_grad():
+ embeddings = model.encode(tokens, num_encoder_tokens=600)
+ # Shape: [batch, sequence_length, 768]
-spec = result['spectrum']
-print(f"Generated spectrum: ฮป range {spec.wavelength[0]:.0f}-{spec.wavelength[-1]:.0f} ร
, shape={spec.flux.shape}")
+# 5. Predict redshift distribution
+with torch.no_grad():
+ predictions = model(
+ tokens,
+ target_mask={"tok_z": torch.zeros(batch_size, 1)},
+ num_encoder_tokens=600
+ )
+ redshift_logits = predictions["tok_z"]
+ redshift_probs = torch.softmax(redshift_logits, dim=-1)
```
-## Loading and Preprocessing Data
+## Loading and Preparing Data
### Working with Images
-AION-1 expects images in a specific format. Here's how to prepare astronomical images:
+AION-1 expects multi-band astronomical images with specific formatting:
```python
-import numpy as np
+import torch
from astropy.io import fits
-from aion.modalities import Image
-from aion.codecs.preprocessing import ImagePreprocessor
-
-# Load FITS data
-with fits.open('galaxy.fits') as hdul:
- # Assuming multi-band data in extensions
- flux_data = np.array([hdul[i].data for i in range(1, 6)]) # 5 bands
-
-# Create Image modality
-image = Image(
- flux=flux_data,
- bands=['HSC-G', 'HSC-R', 'HSC-I', 'HSC-Z', 'HSC-Y'],
- # Optional: provide inverse variance for optimal processing
- ivar=inverse_variance_data
-)
+from aion.modalities import LegacySurveyImage, HSCImage
+
+# Example 1: Legacy Survey (4-band: g,r,i,z)
+def load_legacy_survey_image(fits_path):
+ """Load and format Legacy Survey FITS data."""
+ with fits.open(fits_path) as hdul:
+ # Assuming bands are in separate extensions
+ flux_data = np.array([hdul[i].data for i in range(1, 5)]) # 4 bands
+
+ image = LegacySurveyImage(
+ flux=torch.tensor(flux_data, dtype=torch.float32),
+ bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z']
+ )
+ return image
+
+# Example 2: HSC (5-band: g,r,i,z,y)
+def load_hsc_image(flux_array):
+ """Load HSC 5-band image data."""
+ image = HSCImage(
+ flux=torch.tensor(flux_array, dtype=torch.float32),
+ bands=['HSC-G', 'HSC-R', 'HSC-I', 'HSC-Z', 'HSC-Y']
+ )
+ return image
-# Apply survey-specific preprocessing
-preprocessor = ImagePreprocessor(survey='HSC')
-processed_image = preprocessor(image)
+# Note: AION-1 automatically crops/pads images to 96x96 pixels
```
### Working with Spectra
-Load and prepare spectroscopic data:
+Load and prepare spectroscopic observations:
```python
-from aion.modalities import Spectrum
-from astropy.io import fits
-
-# Load SDSS spectrum
-hdul = fits.open('spec-plate-mjd-fiber.fits')
-wavelength = 10**hdul[1].data['loglam'] # Convert log wavelength
-flux = hdul[1].data['flux']
-ivar = hdul[1].data['ivar']
-
-# Create Spectrum modality
-spectrum = Spectrum(
- wavelength=wavelength,
- flux=flux,
- ivar=ivar,
- survey='SDSS'
-)
-
-# The model handles resampling to internal wavelength grid automatically
+from aion.modalities import DESISpectrum, SDSSSpectrum
+
+def load_desi_spectrum(flux, ivar, mask, wavelength):
+ """Load DESI spectrum data."""
+ spectrum = DESISpectrum(
+ flux=torch.tensor(flux, dtype=torch.float32),
+ ivar=torch.tensor(ivar, dtype=torch.float32),
+ mask=torch.tensor(mask, dtype=torch.bool),
+ wavelength=torch.tensor(wavelength, dtype=torch.float32)
+ )
+ return spectrum
+
+def load_sdss_spectrum_from_fits(fits_path):
+ """Load SDSS spectrum from FITS file."""
+ with fits.open(fits_path) as hdul:
+ data = hdul[1].data
+ wavelength = 10**data['loglam'] # Convert from log wavelength
+ flux = data['flux']
+ ivar = data['ivar']
+
+ # Create mask for bad pixels
+ mask = (ivar > 0) & (flux > 0)
+
+ spectrum = SDSSSpectrum(
+ flux=torch.tensor(flux, dtype=torch.float32),
+ ivar=torch.tensor(ivar, dtype=torch.float32),
+ mask=torch.tensor(mask, dtype=torch.bool),
+ wavelength=torch.tensor(wavelength, dtype=torch.float32)
+ )
+ return spectrum
```
-### Working with Catalog Data
+### Working with Photometric Data
-Process tabular astronomical measurements:
+Prepare scalar measurements like fluxes and shape parameters:
```python
from aion.modalities import (
- FluxG, FluxR, FluxI, FluxZ,
- E1, E2, RadiusCARP, Redshift
+ LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ,
+ Z, GaiaParallax
)
-# Load catalog data (e.g., from pandas DataFrame)
-catalog_entry = {
- 'flux_g': FluxG(value=catalog_df['flux_g'].values),
- 'flux_r': FluxR(value=catalog_df['flux_r'].values),
- 'e1': E1(value=catalog_df['e1'].values),
- 'e2': E2(value=catalog_df['e2'].values),
- 'radius': RadiusCARP(value=catalog_df['radius'].values)
-}
-```
+def create_photometry_modalities(catalog_data):
+ """Create modalities from catalog measurements."""
+ modalities = []
-## Basic Inference
+ # Photometric fluxes
+ if 'flux_g' in catalog_data:
+ modalities.append(LegacySurveyFluxG(
+ value=torch.tensor(catalog_data['flux_g'], dtype=torch.float32)
+ ))
-### Single Modality Prediction
+ if 'flux_r' in catalog_data:
+ modalities.append(LegacySurveyFluxR(
+ value=torch.tensor(catalog_data['flux_r'], dtype=torch.float32)
+ ))
-Predict missing photometric measurements from available data:
+ # Redshift
+ if 'redshift' in catalog_data:
+ modalities.append(Z(
+ value=torch.tensor(catalog_data['redshift'], dtype=torch.float32)
+ ))
-```python
-# Given g,r,i bands, predict z band
-inputs = {
- 'flux_g': FluxG(value=[19.5]),
- 'flux_r': FluxR(value=[18.2]),
- 'flux_i': FluxI(value=[17.8])
-}
-
-# Predict z-band flux
-with torch.no_grad():
- predictions = model.generate(
- inputs=inputs,
- targets=['flux_z']
- )
-
-z_flux = predictions['flux_z'].value[0]
-print(f"Predicted z-band flux: {z_flux:.2f}")
+ return modalities
```
-### Batch Processing
-
-Process multiple objects efficiently:
+## Basic Workflows
-```python
-# Prepare batch of galaxies
-batch_images = [load_galaxy(i) for i in range(32)]
-batch = {
- 'image': Image.batch(batch_images)
-}
+### Workflow 1: Embedding Extraction
-# Generate properties for all galaxies
-with torch.no_grad():
- results = model.generate(
- inputs=batch,
- targets=['redshift', 'e1', 'e2', 'radius']
- )
-
-# Extract results
-redshifts = results['redshift'].value
-ellipticities = np.sqrt(results['e1'].value**2 + results['e2'].value**2)
-```
-
-## Multimodal Generation
-
-### Conditional Generation
-
-Generate multiple modalities conditioned on partial observations:
+Extract learned representations for downstream machine learning:
```python
-# Complex multimodal generation example
-def analyze_galaxy(image_path, known_redshift=None):
- # Load image
- image = load_and_preprocess_image(image_path)
-
- inputs = {'image': image}
- if known_redshift:
- inputs['redshift'] = Redshift(value=[known_redshift])
-
- # Generate comprehensive analysis
- targets = [
- 'spectrum', # Full spectrum
- 'flux_g', 'flux_r', 'flux_i', 'flux_z', # Photometry
- 'e1', 'e2', # Shape parameters
- 'radius', # Size
- 'parallax', # Distance indicator
- 'extinction_v' # Dust extinction
- ]
+def extract_galaxy_embeddings(data_list, model, codec_manager):
+ """Extract embeddings from a list of galaxy observations."""
+ all_embeddings = []
- with torch.no_grad():
- results = model.generate(
- inputs=inputs,
- targets=targets,
- num_generations=1,
- temperature=1.0
- )
-
- return results
+ # Process in batches for efficiency
+ batch_size = 32
+ for i in range(0, len(data_list), batch_size):
+ batch = data_list[i:i + batch_size]
-# Analyze a galaxy
-galaxy_properties = analyze_galaxy('ngc1234.fits', known_redshift=0.05)
-```
+ # Encode all modalities in the batch
+ batch_tokens = codec_manager.encode(*batch)
-### Uncertainty Quantification
+ # Extract embeddings
+ with torch.no_grad():
+ embeddings = model.encode(batch_tokens, num_encoder_tokens=600)
+ # Pool over sequence dimension
+ pooled = embeddings.mean(dim=1) # [batch, 768]
-Generate multiple samples to estimate uncertainties:
+ all_embeddings.append(pooled.cpu().numpy())
-```python
-def estimate_uncertainty(inputs, target, num_samples=100):
- samples = []
+ return np.vstack(all_embeddings)
- with torch.no_grad():
- for _ in range(num_samples):
- result = model.generate(
- inputs=inputs,
- targets=[target],
- temperature=1.2 # Higher temperature for more diversity
- )
- samples.append(result[target].value[0])
-
- samples = np.array(samples)
- return {
- 'mean': np.mean(samples),
- 'std': np.std(samples),
- 'percentiles': np.percentile(samples, [16, 50, 84])
- }
-
-# Estimate redshift uncertainty
-z_stats = estimate_uncertainty(
- inputs={'image': galaxy_image},
- target='redshift'
+# Usage example
+galaxy_embeddings = extract_galaxy_embeddings(
+ [image1, image2, image3, ...],
+ model,
+ codec_manager
)
-print(f"Redshift: {z_stats['mean']:.3f} ยฑ {z_stats['std']:.3f}")
```
-## Cross-Modal Translation
-
-### Image to Spectrum
+### Workflow 2: Redshift Prediction
-Convert imaging observations to spectroscopic predictions:
+Predict redshift distributions from various input modalities:
```python
-def image_to_spectrum(image, wavelength_range=(3800, 9200)):
- """Generate spectrum from multi-band image."""
+def predict_redshift_distribution(inputs, model, codec_manager):
+ """Predict redshift probability distribution."""
+ # Encode inputs
+ tokens = codec_manager.encode(*inputs)
- # Generate spectrum tokens
+ # Predict redshift
with torch.no_grad():
- result = model.generate(
- inputs={'image': image},
- targets=['spectrum']
+ predictions = model(
+ tokens,
+ target_mask={"tok_z": torch.zeros(len(inputs), 1)},
+ num_encoder_tokens=600
)
- spectrum = result['spectrum']
-
- # Filter to desired wavelength range
- mask = (spectrum.wavelength >= wavelength_range[0]) & \
- (spectrum.wavelength <= wavelength_range[1])
+ # Convert to probabilities
+ redshift_logits = predictions["tok_z"]
+ redshift_probs = torch.softmax(redshift_logits, dim=-1)
- return {
- 'wavelength': spectrum.wavelength[mask],
- 'flux': spectrum.flux[mask]
- }
+ return redshift_probs
-# Generate and plot spectrum
-synthetic_spec = image_to_spectrum(galaxy_image)
-plt.plot(synthetic_spec['wavelength'], synthetic_spec['flux'])
-plt.xlabel('Wavelength (ร
)')
-plt.ylabel('Flux')
-plt.title('AION-1 Generated Spectrum from Image')
+# Example: Predict from photometry
+redshift_dist = predict_redshift_distribution(
+ [flux_g, flux_r, flux_i, flux_z],
+ model,
+ codec_manager
+)
```
-### Spectrum to Image
+### Workflow 3: Reconstruction
-Inverse translation - generate images from spectra:
+Reconstruct modalities through the encode-decode process:
```python
-def spectrum_to_image(spectrum, bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z']):
- """Generate multi-band image from spectrum."""
-
- with torch.no_grad():
- result = model.generate(
- inputs={'spectrum': spectrum},
- targets=['image'],
- target_bands=bands
- )
+def reconstruct_modality(original_modality, model, codec_manager, modality_class, **metadata):
+ """Reconstruct a modality through encode-decode cycle."""
+ # Encode original
+ tokens = codec_manager.encode(original_modality)
+
+ # Decode back
+ reconstructed = codec_manager.decode(
+ tokens,
+ modality_class,
+ **metadata
+ )
- return result['image']
+ return reconstructed
-# Reconstruct galaxy appearance
-reconstructed_image = spectrum_to_image(observed_spectrum)
+# Example: Reconstruct image
+reconstructed_image = reconstruct_modality(
+ original_image,
+ model,
+ codec_manager,
+ LegacySurveyImage,
+ bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z']
+)
```
-### Super-Resolution
+## Embedding Extraction
-Enhance low-resolution spectra using multimodal context:
+### Basic Embedding Extraction
```python
-def enhance_spectrum(low_res_spectrum, supporting_data=None):
- """Enhance spectrum resolution using additional modalities."""
-
- inputs = {'spectrum': low_res_spectrum}
+def get_embeddings(modalities, model, codec_manager, pooling='mean'):
+ """Extract embeddings with different pooling strategies."""
+ tokens = codec_manager.encode(*modalities)
- # Add supporting data if available
- if supporting_data:
- inputs.update(supporting_data)
-
- # Generate high-resolution version
with torch.no_grad():
- result = model.generate(
- inputs=inputs,
- targets=['spectrum_highres'],
- num_generations=1
- )
+ embeddings = model.encode(tokens, num_encoder_tokens=600)
+
+ # Apply pooling
+ if pooling == 'mean':
+ return embeddings.mean(dim=1)
+ elif pooling == 'max':
+ return embeddings.max(dim=1)[0]
+ elif pooling == 'cls':
+ return embeddings[:, 0] # First token
+ else:
+ return embeddings # Return full sequence
- return result['spectrum_highres']
-
-# Example with photometric support
-enhanced = enhance_spectrum(
- sdss_spectrum,
- supporting_data={
- 'flux_g': FluxG(value=[18.5]),
- 'flux_r': FluxR(value=[17.2])
- }
-)
+# Usage
+embeddings = get_embeddings([image, spectrum], model, codec_manager)
```
-## Representation Learning
-
-### Extracting Embeddings
+### Multi-Modal Embeddings
-Use AION-1's learned representations for downstream tasks:
+Combine embeddings from different modalities:
```python
-def extract_embeddings(data_dict, pool='mean'):
- """Extract feature embeddings from AION-1 encoder."""
-
- # Tokenize inputs
- tokens = model.tokenize(data_dict)
+def get_multimodal_embeddings(image, spectrum, photometry, model, codec_manager):
+ """Extract embeddings from multiple modality types."""
- # Get encoder representations
- with torch.no_grad():
- embeddings = model.encode(tokens)
-
- # Pool over sequence dimension
- if pool == 'mean':
- features = embeddings.mean(dim=1)
- elif pool == 'cls':
- features = embeddings[:, 0] # First token
- elif pool == 'max':
- features = embeddings.max(dim=1)[0]
-
- return features.cpu().numpy()
-
-# Extract features for clustering
-galaxy_features = extract_embeddings({
- 'image': galaxy_image,
- 'spectrum': galaxy_spectrum
-})
-```
+ # Get embeddings from each modality type
+ image_tokens = codec_manager.encode(image)
+ spectrum_tokens = codec_manager.encode(spectrum)
+ photo_tokens = codec_manager.encode(*photometry)
-### Similarity Search
+ embeddings = {}
-Find similar objects using learned representations:
+ with torch.no_grad():
+ # Image embeddings
+ img_emb = model.encode(image_tokens, num_encoder_tokens=300)
+ embeddings['image'] = img_emb.mean(dim=1)
-```python
-from sklearn.metrics.pairwise import cosine_similarity
+ # Spectrum embeddings
+ spec_emb = model.encode(spectrum_tokens, num_encoder_tokens=300)
+ embeddings['spectrum'] = spec_emb.mean(dim=1)
-class GalaxySimilaritySearch:
- def __init__(self, model):
- self.model = model
- self.database = []
- self.embeddings = []
-
- def add_galaxy(self, galaxy_data, metadata=None):
- """Add galaxy to search database."""
- embedding = extract_embeddings(galaxy_data)
- self.embeddings.append(embedding)
- self.database.append({
- 'data': galaxy_data,
- 'metadata': metadata,
- 'embedding': embedding
- })
-
- def find_similar(self, query_data, k=10):
- """Find k most similar galaxies."""
- query_embedding = extract_embeddings(query_data)
-
- # Compute similarities
- similarities = cosine_similarity(
- query_embedding.reshape(1, -1),
- np.vstack(self.embeddings)
- )[0]
-
- # Get top k
- indices = np.argsort(similarities)[::-1][:k]
-
- return [(self.database[i], similarities[i]) for i in indices]
+ # Combined embeddings
+ all_tokens = {**image_tokens, **spectrum_tokens, **photo_tokens}
+ combined_emb = model.encode(all_tokens, num_encoder_tokens=900)
+ embeddings['combined'] = combined_emb.mean(dim=1)
-# Usage
-searcher = GalaxySimilaritySearch(model)
-# ... add galaxies to database ...
-similar_galaxies = searcher.find_similar(query_galaxy, k=5)
+ return embeddings
```
-### Anomaly Detection
+## Similarity Search
-Identify unusual objects using reconstruction error:
+Implement similarity search using AION embeddings:
```python
-def detect_anomalies(galaxies, threshold_percentile=95):
- """Detect anomalous galaxies using reconstruction error."""
-
- reconstruction_errors = []
+from sklearn.metrics.pairwise import cosine_similarity
+from sklearn.neighbors import NearestNeighbors
- for galaxy in galaxies:
- # Encode and decode
+class AIONSimilaritySearch:
+ def __init__(self, model, codec_manager):
+ self.model = model
+ self.codec_manager = codec_manager
+ self.database_embeddings = []
+ self.database_objects = []
+ self.index = None
+
+ def add_objects(self, objects):
+ """Add objects to the search database."""
+ for obj in objects:
+ # Extract embedding
+ tokens = self.codec_manager.encode(*obj['modalities'])
+ with torch.no_grad():
+ emb = self.model.encode(tokens, num_encoder_tokens=600)
+ emb = emb.mean(dim=1).cpu().numpy()
+
+ self.database_embeddings.append(emb)
+ self.database_objects.append(obj)
+
+ # Build search index
+ if self.database_embeddings:
+ embeddings_matrix = np.vstack(self.database_embeddings)
+ self.index = NearestNeighbors(n_neighbors=10, metric='cosine')
+ self.index.fit(embeddings_matrix)
+
+ def search(self, query_modalities, k=5):
+ """Search for similar objects."""
+ # Get query embedding
+ tokens = self.codec_manager.encode(*query_modalities)
with torch.no_grad():
- reconstructed = model.generate(
- inputs=galaxy,
- targets=list(galaxy.keys())
- )
-
- # Compute reconstruction error
- error = 0
- for key in galaxy:
- if key == 'image':
- error += np.mean((galaxy[key].flux -
- reconstructed[key].flux)**2)
- elif hasattr(galaxy[key], 'value'):
- error += np.mean((galaxy[key].value -
- reconstructed[key].value)**2)
-
- reconstruction_errors.append(error)
-
- # Set threshold
- threshold = np.percentile(reconstruction_errors, threshold_percentile)
-
- # Identify anomalies
- anomalies = [g for g, e in zip(galaxies, reconstruction_errors)
- if e > threshold]
-
- return anomalies, reconstruction_errors
+ query_emb = self.model.encode(tokens, num_encoder_tokens=600)
+ query_emb = query_emb.mean(dim=1).cpu().numpy()
+
+ # Find nearest neighbors
+ distances, indices = self.index.kneighbors(query_emb, n_neighbors=k)
+
+ results = []
+ for i, idx in enumerate(indices[0]):
+ results.append({
+ 'object': self.database_objects[idx],
+ 'similarity': 1 - distances[0][i], # Convert distance to similarity
+ 'rank': i + 1
+ })
+
+ return results
+
+# Usage example
+searcher = AIONSimilaritySearch(model, codec_manager)
+
+# Add objects to database
+database_objects = [
+ {'modalities': [image1, spectrum1], 'metadata': {'id': 'galaxy_1'}},
+ {'modalities': [image2, spectrum2], 'metadata': {'id': 'galaxy_2'}},
+ # ... more objects
+]
+searcher.add_objects(database_objects)
+
+# Search for similar objects
+query_galaxy = [query_image, query_spectrum]
+similar_objects = searcher.search(query_galaxy, k=10)
+
+print(f"Found {len(similar_objects)} similar objects:")
+for result in similar_objects:
+ print(f"Rank {result['rank']}: {result['object']['metadata']['id']} "
+ f"(similarity: {result['similarity']:.3f})")
```
-## Advanced Applications
+## Property Prediction
-### Multi-Survey Integration
+Use AION embeddings for various prediction tasks:
-Combine observations from different surveys:
+### Redshift Estimation with k-NN
```python
-def integrate_multi_survey(hsc_image, sdss_spectrum, desi_spectrum=None):
- """Integrate observations from multiple surveys."""
+from sklearn.neighbors import KNeighborsRegressor
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import mean_absolute_error, r2_score
- inputs = {
- 'image': hsc_image,
- 'spectrum_sdss': sdss_spectrum
- }
+def train_redshift_predictor(galaxies_with_redshifts, model, codec_manager):
+ """Train a k-NN regressor for redshift prediction."""
- if desi_spectrum:
- inputs['spectrum_desi'] = desi_spectrum
+ # Extract embeddings and targets
+ embeddings = []
+ redshifts = []
- # Generate unified representation
- with torch.no_grad():
- # Extract all available properties
- results = model.generate(
- inputs=inputs,
- targets=['redshift', 'stellar_mass', 'sfr', 'metallicity']
- )
+ for galaxy in galaxies_with_redshifts:
+ tokens = codec_manager.encode(*galaxy['modalities'])
+ with torch.no_grad():
+ emb = model.encode(tokens, num_encoder_tokens=600)
+ emb = emb.mean(dim=1).cpu().numpy()
- # Generate missing modalities
- if not desi_spectrum:
- results['spectrum_desi'] = model.generate(
- inputs=inputs,
- targets=['spectrum_desi']
- )['spectrum_desi']
+ embeddings.append(emb[0]) # Remove batch dimension
+ redshifts.append(galaxy['redshift'])
- return results
-```
+ X = np.array(embeddings)
+ y = np.array(redshifts)
-### Time Series Analysis
+ # Split data
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, test_size=0.2, random_state=42
+ )
-Analyze variable objects across epochs:
+ # Train k-NN regressor
+ knn = KNeighborsRegressor(n_neighbors=5)
+ knn.fit(X_train, y_train)
-```python
-def analyze_variable_object(observations):
- """
- Analyze time-variable astronomical object.
+ # Evaluate
+ y_pred = knn.predict(X_test)
+ mae = mean_absolute_error(y_test, y_pred)
+ r2 = r2_score(y_test, y_pred)
- observations: list of (time, data_dict) tuples
- """
+ print(f"Redshift prediction - MAE: {mae:.4f}, Rยฒ: {r2:.4f}")
- embeddings_over_time = []
- properties_over_time = []
+ return knn
- for time, data in observations:
- # Extract embeddings
- embedding = extract_embeddings(data)
- embeddings_over_time.append(embedding)
+def predict_redshift(new_galaxy, trained_model, model, codec_manager):
+ """Predict redshift for a new galaxy."""
+ tokens = codec_manager.encode(*new_galaxy)
+ with torch.no_grad():
+ emb = model.encode(tokens, num_encoder_tokens=600)
+ emb = emb.mean(dim=1).cpu().numpy()
- # Predict properties
- with torch.no_grad():
- props = model.generate(
- inputs=data,
- targets=['flux_g', 'flux_r', 'temperature']
- )
-
- properties_over_time.append({
- 'time': time,
- 'properties': props,
- 'embedding': embedding
- })
-
- # Analyze evolution
- embeddings = np.vstack(embeddings_over_time)
-
- # Detect significant changes
- embedding_distances = np.sqrt(np.sum(np.diff(embeddings, axis=0)**2, axis=1))
- change_points = np.where(embedding_distances > np.std(embedding_distances) * 2)[0]
-
- return {
- 'properties': properties_over_time,
- 'change_points': change_points,
- 'embedding_evolution': embeddings
- }
+ predicted_z = trained_model.predict(emb)[0]
+ return predicted_z
```
-### Physical Parameter Estimation
-
-Estimate astrophysical parameters with uncertainty:
+### Stellar Mass Prediction
```python
-class PhysicalParameterEstimator:
- def __init__(self, model, num_samples=100):
- self.model = model
- self.num_samples = num_samples
-
- def estimate_parameters(self, observations):
- """Estimate physical parameters with uncertainties."""
+from sklearn.ensemble import RandomForestRegressor
- # Parameters to estimate
- parameters = [
- 'redshift', 'stellar_mass', 'sfr',
- 'metallicity', 'age', 'extinction_v'
- ]
+def train_stellar_mass_predictor(galaxies_with_masses, model, codec_manager):
+ """Train predictor for stellar mass estimation."""
- # Generate multiple samples
- samples = {param: [] for param in parameters}
+ # Similar to redshift prediction but for stellar mass
+ embeddings = []
+ masses = []
+ for galaxy in galaxies_with_masses:
+ tokens = codec_manager.encode(*galaxy['modalities'])
with torch.no_grad():
- for _ in range(self.num_samples):
- results = self.model.generate(
- inputs=observations,
- targets=parameters,
- temperature=1.1
- )
-
- for param in parameters:
- if param in results:
- samples[param].append(results[param].value[0])
-
- # Compute statistics
- estimates = {}
- for param, values in samples.items():
- if values:
- values = np.array(values)
- estimates[param] = {
- 'median': np.median(values),
- 'mean': np.mean(values),
- 'std': np.std(values),
- 'ci_68': np.percentile(values, [16, 84]),
- 'ci_95': np.percentile(values, [2.5, 97.5])
- }
-
- return estimates
+ emb = model.encode(tokens, num_encoder_tokens=600)
+ emb = emb.mean(dim=1).cpu().numpy()
-# Usage
-estimator = PhysicalParameterEstimator(model)
-parameters = estimator.estimate_parameters({
- 'image': galaxy_image,
- 'spectrum': galaxy_spectrum
-})
-
-print(f"Stellar Mass: {parameters['stellar_mass']['median']:.2e} "
- f"+/- {parameters['stellar_mass']['std']:.2e} M_sun")
-```
+ embeddings.append(emb[0])
+ masses.append(np.log10(galaxy['stellar_mass'])) # Log stellar mass
-## Performance Optimization
-
-### Efficient Batch Processing
-
-```python
-from torch.utils.data import DataLoader, Dataset
+ X = np.array(embeddings)
+ y = np.array(masses)
-class AIONDataset(Dataset):
- def __init__(self, data_list):
- self.data = data_list
+ # Train Random Forest
+ rf = RandomForestRegressor(n_estimators=100, random_state=42)
+ rf.fit(X, y)
- def __len__(self):
- return len(self.data)
-
- def __getitem__(self, idx):
- return self.data[idx]
-
-def process_large_dataset(data_list, batch_size=32):
- """Efficiently process large datasets."""
-
- dataset = AIONDataset(data_list)
- dataloader = DataLoader(dataset, batch_size=batch_size,
- num_workers=4, pin_memory=True)
+ return rf
+```
- all_results = []
+## Performance Tips
- with torch.no_grad():
- for batch in dataloader:
- # Process batch
- results = model.generate(
- inputs=batch,
- targets=['redshift', 'stellar_mass']
- )
- all_results.append(results)
-
- # Concatenate results
- return {k: np.concatenate([r[k].value for r in all_results])
- for k in all_results[0]}
-```
+### Batch Processing
-### Memory-Efficient Processing
+Process multiple objects efficiently:
```python
-def process_with_chunking(large_spectrum, chunk_size=1000):
- """Process very long spectra in chunks."""
+def process_batch_efficiently(object_list, model, codec_manager, batch_size=32):
+ """Process objects in batches for better GPU utilization."""
+ results = []
- n_chunks = len(large_spectrum.wavelength) // chunk_size + 1
- chunk_results = []
+ for i in range(0, len(object_list), batch_size):
+ batch = object_list[i:i + batch_size]
- for i in range(n_chunks):
- start = i * chunk_size
- end = min((i + 1) * chunk_size, len(large_spectrum.wavelength))
+ # Group by modality type for efficient encoding
+ images = [obj for obj in batch if 'image' in obj]
+ spectra = [obj for obj in batch if 'spectrum' in obj]
- chunk = Spectrum(
- wavelength=large_spectrum.wavelength[start:end],
- flux=large_spectrum.flux[start:end]
- )
+ batch_results = []
with torch.no_grad():
- result = model.process_spectrum_chunk(chunk)
- chunk_results.append(result)
+ # Process images
+ if images:
+ image_batch = [obj['image'] for obj in images]
+ tokens = codec_manager.encode(*image_batch)
+ embeddings = model.encode(tokens, num_encoder_tokens=600)
+ batch_results.extend(embeddings.mean(dim=1).cpu().numpy())
+
+ # Process spectra
+ if spectra:
+ spectrum_batch = [obj['spectrum'] for obj in spectra]
+ tokens = codec_manager.encode(*spectrum_batch)
+ embeddings = model.encode(tokens, num_encoder_tokens=300)
+ batch_results.extend(embeddings.mean(dim=1).cpu().numpy())
+
+ results.extend(batch_results)
- # Combine chunks
- return combine_spectrum_chunks(chunk_results)
+ return results
```
-### GPU Memory Management
+### Memory Management
-```python
-import gc
+Handle large datasets with limited GPU memory:
-def memory_efficient_generation(inputs, targets, max_batch=16):
- """Generate with automatic batch size adjustment."""
+```python
+def process_large_dataset(dataset, model, codec_manager, max_batch_size=16):
+ """Process large datasets with automatic memory management."""
+ import gc
- batch_size = max_batch
+ current_batch_size = max_batch_size
+ results = []
- while batch_size > 0:
+ i = 0
+ while i < len(dataset):
try:
+ batch = dataset[i:i + current_batch_size]
+
+ # Process batch
+ batch_tokens = codec_manager.encode(*batch)
with torch.no_grad():
- results = model.generate(
- inputs=inputs,
- targets=targets,
- batch_size=batch_size
- )
- return results
+ embeddings = model.encode(batch_tokens, num_encoder_tokens=600)
+ results.append(embeddings.mean(dim=1).cpu())
+
+ i += current_batch_size
except torch.cuda.OutOfMemoryError:
- # Clear cache and try smaller batch
+ # Clear memory and reduce batch size
torch.cuda.empty_cache()
gc.collect()
- batch_size //= 2
+ current_batch_size = max(1, current_batch_size // 2)
+ print(f"Reduced batch size to {current_batch_size}")
- if batch_size == 0:
- raise RuntimeError("Cannot fit even batch size 1")
+ if current_batch_size == 0:
+ raise RuntimeError("Cannot process even single example")
- raise RuntimeError("Failed to process")
+ return torch.cat(results, dim=0)
```
-## Best Practices
+### Using Mixed Precision
-### 1. Data Preparation
-- Always normalize and preprocess data according to survey specifications
-- Provide inverse variance when available for optimal results
-- Use appropriate data types for each modality
+Speed up inference with automatic mixed precision:
-### 2. Model Selection
-- Use `aion-tiny` for quick experiments and limited GPU memory
-- Use `aion-base` for most research applications
-- Use `aion-large` for highest accuracy when computational resources permit
+```python
+def extract_embeddings_amp(modalities, model, codec_manager):
+ """Extract embeddings using automatic mixed precision."""
+ from torch.cuda.amp import autocast
-### 3. Generation Settings
-- Lower temperature (0.8-1.0) for more deterministic outputs
-- Higher temperature (1.1-1.5) for diversity and uncertainty estimation
-- Multiple generations for robust uncertainty quantification
+ tokens = codec_manager.encode(*modalities)
-### 4. Error Handling
-```python
-def safe_generate(model, inputs, targets, fallback=None):
- """Safely generate with error handling."""
- try:
- return model.generate(inputs=inputs, targets=targets)
- except Exception as e:
- print(f"Generation failed: {e}")
- return fallback or {t: None for t in targets}
+ with torch.no_grad():
+ with autocast():
+ embeddings = model.encode(tokens, num_encoder_tokens=600)
+
+ return embeddings.float() # Convert back to float32
```
-## Conclusion
+## Best Practices
+
+1. **Always use `.eval()` mode** for inference to disable dropout and batch norm updates
+2. **Use `torch.no_grad()`** to disable gradient computation and save memory
+3. **Process in batches** when possible for better GPU utilization
+4. **Pool embeddings appropriately** - mean pooling works well for most tasks
+5. **Use consistent device placement** - ensure all tensors are on the same device
+6. **Clear GPU cache** periodically when processing large datasets
+
+## Troubleshooting
+
+### Common Issues
-AION-1 provides a powerful and flexible framework for multimodal astronomical analysis. Its ability to seamlessly integrate diverse observations enables new research possibilities:
+1. **CUDA out of memory**: Reduce batch size or use gradient checkpointing
+2. **Slow processing**: Ensure data is on GPU and use batch processing
+3. **Shape mismatches**: Check that tensor dimensions match expected format
+4. **Device errors**: Ensure model, data, and codec_manager are on same device
-- Cross-modal prediction and generation
-- Unified analysis across multiple surveys
-- Robust uncertainty quantification
-- Discovery of unusual objects
-- Efficient processing of large datasets
+### Debug Mode
+
+```python
+def debug_tokens(tokens, codec_manager):
+ """Debug token shapes and contents."""
+ print("Token summary:")
+ for key, tensor in tokens.items():
+ print(f" {key}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}")
+ print(f" range: [{tensor.min().item():.2f}, {tensor.max().item():.2f}]")
+```
-For more examples and the latest updates, visit the [AION GitHub repository](https://github.com/polymathic-ai/aion) and join our community discussions.
+For more advanced examples and the latest updates, see the [Tutorial Notebook](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb).