diff --git a/docs/api.md b/docs/api.md index 8382838..f861f32 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1,691 +1,363 @@ # API Reference -This comprehensive API reference covers all major components of AION-1, including modalities, codecs, models, and utilities. +This API reference covers the core components you'll actually use with AION-1, based on the working implementation. ## Core Model ### `aion.AION` -The main AION model class that provides high-level interfaces for multimodal astronomical analysis. +The main AION model class that provides multimodal astronomical analysis. ```python -class AION(FourM): +from aion import AION + +class AION(FM): """ AION-1 multimodal astronomical foundation model. - Inherits from FourM architecture and adds astronomical-specific - functionality for processing 39 different data modalities. + Inherits from FM (4M) architecture and adds astronomical-specific + functionality for processing multiple data modalities. """ @classmethod - def from_pretrained( - cls, - model_name: str, - device: str = 'cuda', - torch_dtype: torch.dtype = torch.float32, - **kwargs - ) -> 'AION': + def from_pretrained(cls, model_name: str, **kwargs) -> 'AION': """ - Load a pre-trained AION model. + Load a pre-trained AION model from HuggingFace Hub. 
Args: model_name: HuggingFace model identifier - - 'polymathic-ai/aion-tiny': 300M parameter model - - 'polymathic-ai/aion-base': 800M parameter model - - 'polymathic-ai/aion-large': 3.1B parameter model - device: Device to load model on ('cuda', 'cpu', 'mps') - torch_dtype: Data type for model weights - **kwargs: Additional arguments passed to model constructor + - 'polymathic-ai/aion-base': 300M parameter model Returns: AION model instance + + Example: + >>> model = AION.from_pretrained('polymathic-ai/aion-base') + >>> model = model.to('cuda').eval() """ - def generate( + def forward( self, - inputs: Dict[str, Modality], - targets: List[str], - num_generations: int = 1, - temperature: float = 1.0, - top_k: Optional[int] = None, - top_p: Optional[float] = None - ) -> Dict[str, Modality]: + input_tokens: Dict[str, torch.Tensor], + target_mask: Optional[Dict[str, torch.Tensor]] = None, + num_encoder_tokens: int = 600, + **kwargs + ) -> Dict[str, torch.Tensor]: """ - Generate target modalities from input observations. - - Note: - ``targets`` must be chosen from the list returned by - ``AION.supported_targets`` (essentially the 39 modality names - listed in the architecture documentation). Supplying an - unsupported string will raise ``ValueError``. + Forward pass through the model. 
Args: - inputs: Dictionary mapping modality names to data - targets: List of modality names to generate - num_generations: Number of samples to generate - temperature: Sampling temperature (higher = more diverse) - top_k: Top-k sampling parameter - top_p: Nucleus sampling parameter + input_tokens: Dictionary mapping modality token keys to token tensors + target_mask: Dictionary specifying which tokens to predict + Format: {"tok_z": torch.zeros(batch_size, num_target_tokens)} + num_encoder_tokens: Number of tokens to use in encoder Returns: - Dictionary mapping target names to generated modalities + Dictionary mapping target keys to prediction logits + + Example: + >>> predictions = model( + ... tokens, + ... target_mask={"tok_z": torch.zeros(32, 1)}, + ... num_encoder_tokens=600 + ... ) + >>> redshift_probs = torch.softmax(predictions["tok_z"], dim=-1) """ def encode( self, - inputs: Dict[str, torch.Tensor] + input_tokens: Dict[str, torch.Tensor], + num_encoder_tokens: int = 600 ) -> torch.Tensor: """ - Encode input tokens to learned representations. + Extract embeddings from input tokens. Args: - inputs: Tokenized inputs + input_tokens: Dictionary of tokenized modality data + num_encoder_tokens: Number of tokens for encoder processing Returns: - Encoder hidden states [batch, seq_len, hidden_dim] - """ + Encoder embeddings with shape [batch, seq_len, hidden_dim] - def tokenize( - self, - modalities: Dict[str, Modality] - ) -> Dict[str, torch.Tensor]: + Example: + >>> embeddings = model.encode(tokens, num_encoder_tokens=600) + >>> # Use embeddings for downstream tasks + >>> pooled = embeddings.mean(dim=1) # [batch, hidden_dim] """ - Convert modalities to discrete tokens using codecs. - - Args: - modalities: Dictionary of modality data - - Returns: - Dictionary of tokenized tensors - """ -``` - -## Modalities - -AION-1 supports 39 different astronomical data modalities. Each modality is represented by a Pydantic model ensuring type safety and validation. 
- -### Image Modalities - -#### `aion.modalities.Image` - -```python -class Image(Modality): - """ - Multi-band astronomical image. - - Attributes: - flux: Image data array [bands, height, width] - bands: List of band identifiers (e.g., ['HSC-G', 'HSC-R']) - ivar: Optional inverse variance array for weighting - mask: Optional boolean mask array - """ - - flux: np.ndarray - bands: List[str] - ivar: Optional[np.ndarray] = None - mask: Optional[np.ndarray] = None - - @classmethod - def batch(cls, images: List['Image']) -> 'Image': - """Batch multiple images together.""" - - def crop(self, size: int = 96) -> 'Image': - """Center crop image to specified size.""" -``` - -### Spectrum Modalities - -#### `aion.modalities.Spectrum` - -```python -class Spectrum(Modality): - """ - Astronomical spectrum. - - Attributes: - wavelength: Wavelength array in Angstroms - flux: Flux density array - ivar: Optional inverse variance - survey: Source survey identifier - """ - - wavelength: np.ndarray - flux: np.ndarray - ivar: Optional[np.ndarray] = None - survey: Optional[str] = None - - def resample( - self, - new_wavelength: np.ndarray - ) -> 'Spectrum': - """Resample spectrum to new wavelength grid.""" - - def normalize(self) -> 'Spectrum': - """Apply median normalization.""" -``` - -### Scalar Modalities - -AION-1 includes numerous scalar modalities for photometry, shapes, and physical parameters: - -#### Photometric Fluxes - -```python -class FluxG(ScalarModality): - """g-band flux measurement.""" - value: np.ndarray - error: Optional[np.ndarray] = None - -class FluxR(ScalarModality): - """r-band flux measurement.""" - value: np.ndarray - error: Optional[np.ndarray] = None - -class FluxI(ScalarModality): - """i-band flux measurement.""" - value: np.ndarray - error: Optional[np.ndarray] = None - -class FluxZ(ScalarModality): - """z-band flux measurement.""" - value: np.ndarray - error: Optional[np.ndarray] = None ``` -#### Shape Parameters +## Codec Management -```python -class 
E1(ScalarModality): - """First ellipticity component.""" - value: np.ndarray - -class E2(ScalarModality): - """Second ellipticity component.""" - value: np.ndarray - -class RadiusCARP(ScalarModality): - """CARP radius measurement.""" - value: np.ndarray -``` +### `aion.codecs.CodecManager` -#### Physical Properties +Manages automatic loading and application of modality-specific codecs. ```python -class Redshift(ScalarModality): - """Spectroscopic or photometric redshift.""" - value: np.ndarray - error: Optional[np.ndarray] = None - -class ExtinctionV(ScalarModality): - """V-band extinction.""" - value: np.ndarray - -class Parallax(ScalarModality): - """Parallax measurement in mas.""" - value: np.ndarray - error: Optional[np.ndarray] = None -``` - -### Catalog Modalities +from aion.codecs import CodecManager -#### `aion.modalities.Catalog` - -```python -class Catalog(Modality): +class CodecManager: """ - Astronomical object catalog. - - Attributes: - entries: List of catalog objects - max_objects: Maximum number of objects to process + Central manager for encoding/decoding between modalities and tokens. """ - entries: List[CatalogEntry] - max_objects: int = 100 - - def sort_by_distance(self) -> 'Catalog': - """Sort entries by distance from center.""" - - def filter_bright(self, magnitude_limit: float) -> 'Catalog': - """Filter to objects brighter than limit.""" -``` - -## Codecs (Tokenizers) - -Codecs convert between modalities and discrete tokens. Each modality type has a specialized codec. - -### Base Codec Interface - -#### `aion.codecs.base.Codec` - -```python -class Codec(ABC): - """ - Abstract base class for modality codecs. 
- """ - - @abstractmethod - def encode(self, modality: Modality) -> torch.Tensor: - """Encode modality to discrete tokens.""" - - @abstractmethod - def decode(self, tokens: torch.Tensor) -> Modality: - """Decode tokens back to modality.""" - - @classmethod - def from_pretrained(cls, path: str) -> 'Codec': - """Load pre-trained codec.""" - - def save_pretrained(self, path: str): - """Save codec weights and configuration.""" -``` - -### Image Codec - -#### `aion.codecs.ImageCodec` - -```python -class ImageCodec(Codec): - """ - Image tokenizer using MagVit architecture. - - Supports multi-survey images with different band counts - through a unified channel embedding scheme. - """ - - def __init__( - self, - hidden_dim: int = 512, - n_embed: int = 10000, - compression_levels: int = 2, - quantizer: str = 'fsq' - ): + def __init__(self, device: str = 'cuda'): """ - Initialize image codec. + Initialize codec manager. Args: - hidden_dim: Hidden dimension size - n_embed: Codebook size - compression_levels: Spatial compression factor - quantizer: Quantization method ('fsq' or 'vq') - """ - - def preprocess( - self, - image: Image, - crop_size: int = 96 - ) -> torch.Tensor: - """Apply survey-specific preprocessing.""" - - def get_latent_shape( - self, - image_shape: Tuple[int, ...] - ) -> Tuple[int, ...]: - """Get shape of latent representation.""" -``` + device: Device to load codecs on ('cuda', 'cpu') -### Spectrum Codec - -#### `aion.codecs.SpectrumCodec` - -```python -class SpectrumCodec(Codec): - """ - Spectrum tokenizer using ConvNeXt V2 architecture. - - Uses a shared latent wavelength grid to handle spectra - from different instruments. - """ - - def __init__( - self, - latent_wavelength: np.ndarray, - hidden_dims: List[int] = [96, 192, 384, 768], - n_embed: int = 1024, - quantizer: str = 'lfq' - ): + Example: + >>> codec_manager = CodecManager(device='cuda') """ - Initialize spectrum codec. 
- Args: - latent_wavelength: Target wavelength grid - hidden_dims: ConvNeXt stage dimensions - n_embed: Codebook size - quantizer: Quantization method + def encode(self, *modalities) -> Dict[str, torch.Tensor]: """ + Encode modalities into discrete tokens. - def to_latent_grid( - self, - spectrum: Spectrum - ) -> torch.Tensor: - """Interpolate spectrum to latent wavelength grid.""" -``` - -### Scalar Codec + Args: + *modalities: Variable number of modality objects -#### `aion.codecs.ScalarCodec` + Returns: + Dictionary mapping token keys to token tensors -```python -class ScalarCodec(Codec): - """ - Tokenizer for scalar quantities using adaptive quantization. - """ + Example: + >>> tokens = codec_manager.encode(image, spectrum, flux_g) + >>> # Returns: {"tok_image": tensor(...), "tok_spectrum_sdss": tensor(...), "tok_flux_g": tensor(...)} + """ - def __init__( + def decode( self, - quantizer_type: str = 'reservoir', - n_bins: int = 256 + tokens: Dict[str, torch.Tensor], + modality_class: type, + **metadata ): """ - Initialize scalar codec. + Decode tokens back to modality objects. Args: - quantizer_type: Type of quantizer - - 'linear': Uniform bins - - 'log': Logarithmic bins - - 'reservoir': Learned adaptive bins - - 'compressed': Transform then quantize - n_bins: Number of quantization levels - """ + tokens: Dictionary of token tensors + modality_class: Class of modality to decode (e.g., LegacySurveyImage) + **metadata: Additional metadata required for reconstruction - def fit(self, values: np.ndarray): - """Fit quantizer to data distribution.""" + Returns: + Reconstructed modality object + + Example: + >>> reconstructed = codec_manager.decode( + ... tokens, + ... LegacySurveyImage, + ... bands=["DES-G", "DES-R", "DES-I", "DES-Z"] + ... ) + """ ``` -## Quantizers +## Modalities -Quantization modules that convert continuous values to discrete tokens. +AION-1 uses a typed modality system to ensure data compatibility and provenance tracking. 
-### `aion.codecs.quantizers.FSQ` +### Base Classes ```python -class FiniteScalarQuantization(nn.Module): - """ - Finite Scalar Quantization from MagVit. +from aion.modalities import BaseModality - Factorizes codebook into multiple small codebooks for - better gradient flow and training stability. - """ +class BaseModality: + """Base class for all astronomical modalities.""" - def __init__( - self, - levels: List[int] = [8, 5, 5, 5, 5], - eps: float = 1e-3 - ): - """ - Args: - levels: Number of levels per dimension - eps: Small constant for numerical stability - """ + @property + def token_key(self) -> str: + """Unique identifier for this modality type in the model.""" ``` -### `aion.codecs.quantizers.LFQ` +### Image Modalities ```python -class LookupFreeQuantization(nn.Module): - """ - Lookup-Free Quantization using entropy regularization. +from aion.modalities import LegacySurveyImage, HSCImage - Achieves quantization without explicit codebook lookup, - improving training efficiency. +class LegacySurveyImage(BaseModality): """ + Legacy Survey multi-band image. - def __init__( - self, - dim: int, - codebook_size: int, - entropy_weight: float = 0.1 - ): - """ - Args: - dim: Embedding dimension - codebook_size: Target vocabulary size - entropy_weight: Entropy regularization weight - """ -``` - -## Preprocessing - -Survey-specific preprocessing utilities. - -### `aion.codecs.preprocessing.ImagePreprocessor` - -```python -class ImagePreprocessor: - """ - Survey-specific image preprocessing. + Attributes: + flux: Image tensor with shape [batch, 4, height, width] for g,r,i,z bands + bands: List of band identifiers (e.g., ['DES-G', 'DES-R', 'DES-I', 'DES-Z']) """ - def __init__(self, survey: str): - """ - Initialize for specific survey. - - Args: - survey: Survey name ('HSC', 'DES', 'SDSS', etc.) 
- """ - - def __call__(self, image: Image) -> torch.Tensor: - """Apply preprocessing pipeline.""" - - def get_rescaling_params(self) -> Dict[str, float]: - """Get survey-specific rescaling parameters.""" -``` + flux: torch.Tensor + bands: List[str] -### `aion.codecs.preprocessing.SpectrumPreprocessor` + @property + def token_key(self) -> str: + return "tok_image" -```python -class SpectrumPreprocessor: - """ - Spectrum normalization and preprocessing. +class HSCImage(BaseModality): """ + HSC multi-band image. - def normalize_median( - self, - spectrum: Spectrum - ) -> Spectrum: - """Apply median normalization.""" - - def mask_skylines( - self, - spectrum: Spectrum - ) -> Spectrum: - """Mask common sky emission lines.""" -``` - -## Model Components - -### `aion.fourm.FourM` - -```python -class FourM(nn.Module): + Attributes: + flux: Image tensor with shape [batch, 5, height, width] for g,r,i,z,y bands + bands: List of band identifiers """ - Base multimodal transformer architecture. - Implements the encoder-decoder architecture with - modality-specific embeddings and flexible attention. - """ + flux: torch.Tensor + bands: List[str] - def __init__( - self, - encoder_depth: int = 12, - decoder_depth: int = 12, - dim: int = 768, - num_heads: int = 12, - mlp_ratio: float = 4.0, - use_bias: bool = False - ): - """Initialize FourM architecture.""" + @property + def token_key(self) -> str: + return "tok_image" ``` -### `aion.fourm.encoder_embeddings` +### Spectrum Modalities ```python -class ModalityEmbedding(nn.Module): - """ - Learnable embeddings for each modality type. +from aion.modalities import DESISpectrum, SDSSSpectrum - Provides both modality identification and survey - provenance information. +class DESISpectrum(BaseModality): """ + DESI spectroscopic observation. 
- def __init__( - self, - num_modalities: int, - num_surveys: int, - embed_dim: int - ): - """Initialize modality embeddings.""" -``` - -## Utilities + Attributes: + flux: Flux density array + ivar: Inverse variance array + mask: Boolean mask array + wavelength: Wavelength array in Angstroms + """ -### `aion.model_utils` + flux: torch.Tensor + ivar: torch.Tensor + mask: torch.Tensor + wavelength: torch.Tensor -```python -def load_codec(modality: str, device: str = 'cuda') -> Codec: - """Load pre-trained codec for modality.""" + @property + def token_key(self) -> str: + return "tok_spectrum_desi" -def create_model_config( - model_size: str = 'base' -) -> Dict[str, Any]: - """Get configuration for model size.""" +class SDSSSpectrum(BaseModality): + """SDSS spectroscopic observation.""" -def count_parameters(model: nn.Module) -> int: - """Count trainable parameters in model.""" + @property + def token_key(self) -> str: + return "tok_spectrum_sdss" ``` -### `aion.generation_utils` +### Scalar Modalities ```python -def sample_with_temperature( - logits: torch.Tensor, - temperature: float = 1.0, - top_k: Optional[int] = None, - top_p: Optional[float] = None -) -> torch.Tensor: - """ - Sample from logits with temperature scaling. 
- - Args: - logits: Model output logits - temperature: Sampling temperature - top_k: Top-k filtering - top_p: Nucleus sampling threshold - - Returns: - Sampled token indices - """ +from aion.modalities import ( + LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ, + Z, GaiaParallax +) -def generate_with_caching( - model: AION, - inputs: Dict[str, torch.Tensor], - max_length: int, - use_cache: bool = True -) -> torch.Tensor: - """Generate tokens with KV caching for efficiency.""" -``` +class LegacySurveyFluxG(BaseModality): + """Legacy Survey g-band flux measurement.""" -## Data Loading + value: torch.Tensor -### `aion.data.AstronomicalDataset` + @property + def token_key(self) -> str: + return "tok_flux_g" -```python -class AstronomicalDataset(Dataset): - """ - PyTorch dataset for astronomical observations. - """ +class Z(BaseModality): + """Spectroscopic redshift.""" - def __init__( - self, - data_paths: List[str], - modalities: List[str], - transform: Optional[Callable] = None - ): - """ - Initialize dataset. - - Args: - data_paths: Paths to data files - modalities: List of modalities to load - transform: Optional data transformation - """ + value: torch.Tensor - def __getitem__(self, idx: int) -> Dict[str, Modality]: - """Get single observation.""" + @property + def token_key(self) -> str: + return "tok_z" ``` -## Example Usage +## Complete Usage Example -### Complete Pipeline +Here's a comprehensive example showing the full workflow: ```python import torch from aion import AION -from aion.modalities import Image, Spectrum -from aion.codecs import ImageCodec, SpectrumCodec +from aion.codecs import CodecManager +from aion.modalities import ( + LegacySurveyImage, DESISpectrum, + LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ +) + +# 1. 
Load model and codec manager +model = AION.from_pretrained('polymathic-ai/aion-base').to('cuda').eval() +codec_manager = CodecManager(device='cuda') -# Load model and codecs -model = AION.from_pretrained('polymathic-ai/aion-base') -image_codec = ImageCodec.from_pretrained('polymathic-ai/aion-image-codec') -spectrum_codec = SpectrumCodec.from_pretrained('polymathic-ai/aion-spectrum-codec') +# 2. Prepare data +image = LegacySurveyImage( + flux=torch.tensor(image_data, dtype=torch.float32), + bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] +) + +spectrum = DESISpectrum( + flux=torch.tensor(flux_data), + ivar=torch.tensor(ivar_data), + mask=torch.tensor(mask_data, dtype=torch.bool), + wavelength=torch.tensor(wavelength_data) +) -# Load data -image = Image(flux=galaxy_flux, bands=['g', 'r', 'i', 'z', 'y']) -spectrum = Spectrum(wavelength=wavelength, flux=flux) +flux_g = LegacySurveyFluxG(value=torch.tensor([flux_g_value])) -# Tokenize -tokens = { - 'image': image_codec.encode(image), - 'spectrum': spectrum_codec.encode(spectrum) -} +# 3. Encode to tokens +tokens = codec_manager.encode(image, spectrum, flux_g) -# Encode to representations +# 4. Extract embeddings for downstream tasks with torch.no_grad(): - representations = model.encode(tokens) + embeddings = model.encode(tokens, num_encoder_tokens=600) + pooled_embeddings = embeddings.mean(dim=1) # [batch, hidden_dim] -# Generate missing modalities -results = model.generate( - inputs={'image': image}, - targets=['spectrum', 'redshift'] +# 5. Predict redshift +with torch.no_grad(): + predictions = model( + tokens, + target_mask={"tok_z": torch.zeros(1, 1)}, + num_encoder_tokens=600 + ) + redshift_probs = torch.softmax(predictions["tok_z"][0], dim=-1) + +# 6. 
Decode tokens back to modalities +reconstructed_image = codec_manager.decode( + tokens, + LegacySurveyImage, + bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] ) - -# Decode results -generated_spectrum = spectrum_codec.decode(results['spectrum']) -print(f"Predicted redshift: {results['redshift'].value[0]:.3f}") ``` -## Error Handling +## Model Variants -All AION components include comprehensive error handling: +Currently available pre-trained models: -```python -from aion.exceptions import ( - ModalityError, # Invalid modality data - CodecError, # Tokenization failures - ModelError, # Model inference errors - DataError # Data loading issues -) +| Model | Parameters | HuggingFace ID | +|-------|------------|----------------| +| AION-Base | 300M | `polymathic-ai/aion-base` | -try: - result = model.generate(inputs, targets) -except ModalityError as e: - print(f"Invalid modality: {e}") -except CodecError as e: - print(f"Tokenization failed: {e}") -``` +More model variants will be added as they become available. + +## Common Patterns -## Performance Tips +### Similarity Search +```python +def compute_similarities(query_tokens, database_tokens, model): + """Compute embedding similarities between query and database.""" + with torch.no_grad(): + query_emb = model.encode(query_tokens).mean(dim=1) + db_embs = model.encode(database_tokens).mean(dim=1) + + from sklearn.metrics.pairwise import cosine_similarity + return cosine_similarity(query_emb.cpu(), db_embs.cpu()) +``` -1. **Batch Processing**: Always process multiple objects together when possible -2. **Mixed Precision**: Use `torch.cuda.amp` for faster inference -3. **Token Caching**: Reuse encoder outputs when generating multiple targets -4. 
**Device Placement**: Use `.to(device)` consistently for all tensors +### Batch Processing +```python +def process_batch(batch_data, model, codec_manager): + """Process a batch of astronomical objects.""" + batch_tokens = codec_manager.encode(*batch_data) -For more details, see the [Usage Guide](usage.md) and [Architecture](architecture.md) documentation. + with torch.no_grad(): + embeddings = model.encode(batch_tokens, num_encoder_tokens=600) -```{eval-rst} -.. automodule:: aion - :members: - :undoc-members: - :show-inheritance: + return embeddings.mean(dim=1) # Pooled embeddings ``` + +For more examples, see the [Usage Guide](usage.md) and [Tutorial Notebook](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb). diff --git a/docs/architecture.md b/docs/architecture.md index 607c7f8..f4062cd 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -222,9 +222,9 @@ AION-1 comes in three sizes, each using the same architecture with different dim | Model | Parameters | Encoder Layers | Decoder Layers | Hidden Dim | Attention Heads | |-------|------------|----------------|----------------|------------|-----------------| -| AION-1-B (Base) | 300M | 12 | 12 | 768 | 12 | -| AION-1-L (Large) | 800M | 24 | 24 | 1024 | 16 | -| AION-1-XL (XLarge) | 3.1B | 24 | 24 | 2048 | 32 | +| AION-Base | ~300M | 12 | 12 | 768 | 12 | + +> **Note**: Additional model sizes may be released in the future. Current model ID: `polymathic-ai/aion-base` All models use: - SwiGLU activation functions diff --git a/docs/index.md b/docs/index.md index 0d44812..d414bae 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,65 +3,27 @@

AION-1

AstronomIcal Omnimodal Network

-

The first large-scale multimodal foundation model for astronomy

+

Large-Scale Multimodal Foundation Model for Astronomy

Get Started → - Read the Paper - Run on Colab + + Run on Colab
``` -# Welcome to AION-1 - -AION-1 (AstronomIcal Omnimodal Network) represents a breakthrough in astronomical machine learning: the first foundation model capable of understanding and processing arbitrary combinations of astronomical observations across 39 different data modalities. Trained on over 200 million astronomical objects, AION-1 unifies imaging, spectroscopy, photometry, and catalog data from major ground- and space-based observatories into a single, powerful framework. +# AION-1 Documentation ## 🌟 Why AION-1? -Traditional approaches in astronomy treat each data modality in isolation, missing the rich interconnections between different types of observations. AION-1 fundamentally changes this paradigm by: +Trained on over 200 million astronomical objects, AION-1 (AstronomIcal Omnimodal Network) is the first Foundation Model capable of unifying multiband imaging, spectroscopy, and photometry from major ground- and space-based observatories into a single framework. -- **Learning Cross-Modal Relationships**: The model discovers how different observations relate to each other, building a deep understanding of the underlying astrophysical objects +Compared to traditional machine learning approaches in Astronomy, AION-1 stands out on several points: - **Enabling Flexible Data Fusion**: Scientists can use any combination of available observations without redesigning their analysis pipeline +- **Enabling Easy Adaptation to Downstream Tasks**: Scientists can adapt AION-1 to new tasks in a matter of minutes and reach SOTA performance - **Excelling in Low-Data Regimes**: AION-1 achieves competitive results with orders of magnitude less labeled data than supervised approaches - **Providing Universal Representations**: The learned embeddings capture physically meaningful structure useful across diverse downstream tasks -## 📊 Key Capabilities - -```{eval-rst} -.. grid:: 1 1 2 3 - :gutter: 3 - - .. 
grid-item-card:: 🌌 39 Data Modalities - :class-card: feature-card - - Seamlessly integrates multiband images, optical spectra, photometry, and catalog data from HSC, Legacy Survey, SDSS, DESI, and Gaia - - .. grid-item-card:: 🧠 200M+ Objects - :class-card: feature-card - - Pre-trained on massive astronomical datasets spanning galaxies, stars, and quasars across multiple surveys - - .. grid-item-card:: 🔧 Flexible Architecture - :class-card: feature-card - - Two-stage design with modality-specific tokenization followed by transformer-based multimodal masked modeling - - .. grid-item-card:: ⚡ Emergent Behaviors - :class-card: feature-card - - Demonstrates physical understanding, superior low-data performance, and meaningful latent space organization - - .. grid-item-card:: 🎯 Versatile Applications - :class-card: feature-card - - Supports regression, classification, generation, retrieval, and cross-modal prediction tasks out-of-the-box - - .. grid-item-card:: ๐ŸŒ Open Science - :class-card: feature-card - - Fully open-source including datasets, training scripts, and model weights for reproducible research -``` - ## 🚀 Quick Start Getting started with AION-1 is straightforward: @@ -69,44 +31,42 @@ Getting started with AION-1 is straightforward: ```python # Minimal end-to-end example from aion import AION -import numpy as np +from aion.codecs import CodecManager +from aion.modalities import (LegacySurveyImage, LegacySurveyFluxG, +LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ) + +# 1) Load a pre-trained checkpoint (300 M parameters) +model = AION.from_pretrained('polymathic-ai/aion-base').to('cuda').eval() +codec_manager = CodecManager(device='cuda') # Manages codecs for each modality + +# 2) Prepare demo inputs (96×96 g,r,i,z cut-out and photometry) +# Create image modality +image = LegacySurveyImage( + flux=data["legacysurvey_image_flux"], + bands=["DES-G", "DES-R", "DES-I", "DES-Z"], +) -# 1) Load a pre-trained checkpoint (800 M parameters) 
-model = AION.from_pretrained('polymathic-ai/aion-base') +# Create flux modalities +g = LegacySurveyFluxG(value=data["legacysurvey_FLUX_G"]) +r = LegacySurveyFluxR(value=data["legacysurvey_FLUX_R"]) +i = LegacySurveyFluxI(value=data["legacysurvey_FLUX_I"]) +z = LegacySurveyFluxZ(value=data["legacysurvey_FLUX_Z"]) -# 2) Prepare demo inputs (96×96 HSC g,r,i,z,y cut-out and SDSS spectrum) -galaxy_image = np.load('hsc_cutout_5band.npy') # shape (5,96,96) -galaxy_spectrum = np.load('sdss_spectrum.npy') # dict with wavelength/flux +# Encode input modalities into tokens +tokens = codec_manager.encode(image, g, r, i, z) -# 3) Generate a high-resolution DESI-like spectrum from the image -generated = model.generate( - inputs={'image': galaxy_image}, - targets=['spectrum'] +# 3) Generate a redshift distribution from this set of inputs +predictions = model( + tokens, + target_mask={"tok_z": torch.zeros(batch_size, 1)}, + num_encoder_tokens=600 ) +redshift_logits = predictions["tok_z"] # Shape: [batch, sequence, vocab_size] # 4) Extract joint embeddings for downstream use -embeddings = model.encode({'image': galaxy_image, 'spectrum': galaxy_spectrum}) +embeddings = model.encode(tokens, num_encoder_tokens=600) # Shape: [batch, seq_len, hidden_dim] ``` -## 🔬 Scientific Impact - -AION-1 demonstrates several emergent behaviors that reflect its deep understanding of astronomical data: - -### Physical Understanding -- Solves non-trivial scientific tasks using only simple linear probes on learned representations -- Organizes objects in embedding space along physically meaningful dimensions -- Captures relationships between disparate observations of the same physical phenomena - -### Performance Advantages -- Achieves state-of-the-art results on galaxy property estimation, stellar parameter prediction, and morphology classification -- Outperforms supervised baselines by 3x on rare object detection tasks -- Enables accurate cross-modal prediction even for modality pairs never seen 
during training - -### Practical Benefits -- Reduces data requirements by orders of magnitude for downstream tasks -- Enables seamless integration of heterogeneous observations -- Provides robust uncertainty quantification through multiple sampling - ## 📚 Documentation Overview ```{eval-rst} @@ -119,11 +79,11 @@ AION-1 demonstrates several emergent behaviors that reflect its deep understandi Environment setup, dependencies, and configuration - .. grid-item-card:: Model Architecture + .. grid-item-card:: Model Specifications :link: architecture.html :class-card: doc-card - Deep dive into tokenization, transformers, and design + Deep dive into tokenization, transformers, and training data .. grid-item-card:: Usage Guide :link: usage.html @@ -147,13 +107,3 @@ architecture usage api ``` - -## 🤝 Join the Community - -```{raw} html
-

Advancing astronomical AI together

-

AION-1 is developed by Polymathic AI in collaboration with the Flatiron Institute and leading astronomical institutions worldwide. We welcome contributions from astronomers, ML researchers, and data scientists interested in pushing the boundaries of multimodal scientific machine learning.

- Start Contributing → -
-``` diff --git a/docs/installation.md b/docs/installation.md index f28e27c..6225e4d 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,89 +1,111 @@ # Installation Guide -This comprehensive guide will walk you through installing AION-1 and setting up your environment for astronomical multimodal analysis. +Quick and straightforward installation guide for AION-1. ## System Requirements ### Hardware Requirements -AION-1 is designed to run efficiently on various hardware configurations: +**Minimum (CPU only)**: +- 16 GB RAM +- 20 GB free storage -- **Minimum Requirements**: - - CPU: 4+ cores (Intel/AMD x86_64 or Apple Silicon) - - RAM: 16 GB - - GPU: NVIDIA GPU with 8GB+ VRAM (optional but recommended) - - Storage: 50 GB free space for models and data +**Recommended (GPU)**: +- NVIDIA GPU with 8GB+ VRAM +- 32 GB RAM +- 50 GB free storage -- **Recommended Requirements**: - - CPU: 8+ cores - - RAM: 32 GB or more - - GPU: NVIDIA GPU with 24GB+ VRAM (e.g., RTX 3090, A5000, or better) - - Storage: 100 GB+ free space - -- **For Large-Scale Processing**: - - Multiple GPUs with NVLink - - 64GB+ RAM - - Fast SSD storage for data loading +**For Large-Scale Processing**: +- NVIDIA GPU with 24GB+ VRAM (e.g., RTX 4090, A5000+) +- 64GB+ RAM ### Software Requirements -- Python 3.10 or later -- CUDA 11.8+ (for GPU support) -- Operating System: Linux, macOS, or Windows - -## Installation Methods +- Python 3.10+ +- CUDA 11.8+ (for GPU acceleration) +- Linux, macOS, or Windows -### 1. Quick Install via PyPI +## Installation -The simplest way to install AION-1 is through PyPI: +### Quick Install (Recommended) ```bash +# Install PyTorch with CUDA support (adjust CUDA version as needed) +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 + +# Install AION pip install aion ``` -This installs the core AION package with minimal dependencies. - -### 2. 
Full Installation with PyTorch - -For GPU support and optimal performance: +### Alternative: CPU-only Installation ```bash -# Install PyTorch first (adjust for your CUDA version) -pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 +# Install CPU-only PyTorch +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu -# Then install AION +# Install AION pip install aion ``` -### 3. Development Installation - -For contributors or those who want the latest features: +### Development Installation ```bash -# Clone the repository git clone https://github.com/polymathic-ai/aion.git cd aion +pip install -e ".[torch,dev]" +``` + +## Verification + +Test your installation: -# Create a virtual environment -python -m venv venv -source venv/bin/activate # On Windows: venv\Scripts\activate +```python +import torch +from aion import AION +from aion.codecs import CodecManager -# Install in development mode -pip install -e ".[dev]" +print(f"PyTorch version: {torch.__version__}") +print(f"CUDA available: {torch.cuda.is_available()}") + +# Test model loading (requires internet connection) +try: + model = AION.from_pretrained('polymathic-ai/aion-base') + print("✓ AION model loaded successfully") +except Exception as e: + print(f"✗ Model loading failed: {e}") + +# Test codec manager +try: + codec_manager = CodecManager(device='cuda' if torch.cuda.is_available() else 'cpu') + print("✓ CodecManager initialized successfully") +except Exception as e: + print(f"✗ CodecManager failed: {e}") ``` -## Setting Up Your Environment +## Troubleshooting -### 1. 
Virtual Environment Setup +### Common Issues -We strongly recommend using a virtual environment: +**CUDA out of memory**: +```bash +# Use smaller model or CPU +model = AION.from_pretrained('polymathic-ai/aion-base').to('cpu') +``` +**HuggingFace connection issues**: ```bash -# Using venv -python -m venv aion-env -source aion-env/bin/activate # On Windows: aion-env\Scripts\activate +# Set up HuggingFace cache directory +export HF_HOME=/path/to/cache +``` -# Using conda -conda create -n aion python=3.10 -conda activate aion +**Import errors**: +```bash +# Reinstall with fresh environment +pip uninstall aion torch +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 +pip install aion ``` + +## Next Steps + +Once installed, try the [Tutorial Notebook](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb) or check the [Usage Guide](usage.md) for examples. diff --git a/docs/usage.md b/docs/usage.md index 5d03545..855dbea 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -1,723 +1,609 @@ # AION-1 Usage Guide -This comprehensive guide demonstrates how to use AION-1 for various astronomical analysis tasks. From basic inference to advanced multimodal generation, you'll learn to leverage AION-1's capabilities for your research. +This comprehensive guide demonstrates how to use AION-1 for various astronomical analysis tasks, based on the actual working implementation. ## Table of Contents 1. [Quick Start](#quick-start) -2. [Loading and Preprocessing Data](#loading-and-preprocessing-data) -3. [Basic Inference](#basic-inference) -4. [Multimodal Generation](#multimodal-generation) -5. [Cross-Modal Translation](#cross-modal-translation) -6. [Representation Learning](#representation-learning) -7. [Advanced Applications](#advanced-applications) -8. [Performance Optimization](#performance-optimization) +2. [Loading and Preparing Data](#loading-and-preparing-data) +3. [Basic Workflows](#basic-workflows) +4. 
[Embedding Extraction](#embedding-extraction) +5. [Similarity Search](#similarity-search) +6. [Property Prediction](#property-prediction) +7. [Performance Tips](#performance-tips) ## Quick Start -Let's begin with a simple example that showcases AION-1's core capabilities: +Here's how to get started with AION-1 in just a few lines: ```python -import torch, numpy as np +import torch +import numpy as np from aion import AION -from aion.modalities import Image +from aion.codecs import CodecManager +from aion.modalities import LegacySurveyImage + +# 1. Load model and codec manager +model = AION.from_pretrained('polymathic-ai/aion-base').to('cuda').eval() +codec_manager = CodecManager(device='cuda') -# 1) Load a checkpoint (300 M parameters) -model = AION.from_pretrained('polymathic-ai/aion-tiny').eval() +# 2. Prepare your astronomical data +image = LegacySurveyImage( + flux=torch.tensor(your_image_data, dtype=torch.float32), # Shape: [batch, 4, height, width] + bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] +) -# 2) Read an example 5-band HSC cut-out (units: nanomaggies) -flux_cube = np.load('hsc_cutout_5band.npy') # shape (5,96,96) -img = Image(flux=flux_cube, bands=['HSC-G','HSC-R','HSC-I','HSC-Z','HSC-Y']) +# 3. Encode to tokens +tokens = codec_manager.encode(image) -# 3) Predict an SDSS-like spectrum (observer-frame, erg s⁻¹ cm⁻² Å⁻¹) -with torch.inference_mode(): - result = model.generate(inputs={'image': img}, targets=['spectrum']) +# 4. Extract embeddings for downstream analysis +with torch.no_grad(): + embeddings = model.encode(tokens, num_encoder_tokens=600) + # Shape: [batch, sequence_length, 768] -spec = result['spectrum'] -print(f"Generated spectrum: λ range {spec.wavelength[0]:.0f}-{spec.wavelength[-1]:.0f} Å, shape={spec.flux.shape}") +# 5. 
Predict redshift distribution +with torch.no_grad(): + predictions = model( + tokens, + target_mask={"tok_z": torch.zeros(batch_size, 1)}, + num_encoder_tokens=600 + ) + redshift_logits = predictions["tok_z"] + redshift_probs = torch.softmax(redshift_logits, dim=-1) ``` -## Loading and Preprocessing Data +## Loading and Preparing Data ### Working with Images -AION-1 expects images in a specific format. Here's how to prepare astronomical images: +AION-1 expects multi-band astronomical images with specific formatting: ```python -import numpy as np +import torch from astropy.io import fits -from aion.modalities import Image -from aion.codecs.preprocessing import ImagePreprocessor - -# Load FITS data -with fits.open('galaxy.fits') as hdul: - # Assuming multi-band data in extensions - flux_data = np.array([hdul[i].data for i in range(1, 6)]) # 5 bands - -# Create Image modality -image = Image( - flux=flux_data, - bands=['HSC-G', 'HSC-R', 'HSC-I', 'HSC-Z', 'HSC-Y'], - # Optional: provide inverse variance for optimal processing - ivar=inverse_variance_data -) +from aion.modalities import LegacySurveyImage, HSCImage + +# Example 1: Legacy Survey (4-band: g,r,i,z) +def load_legacy_survey_image(fits_path): + """Load and format Legacy Survey FITS data.""" + with fits.open(fits_path) as hdul: + # Assuming bands are in separate extensions + flux_data = np.array([hdul[i].data for i in range(1, 5)]) # 4 bands + + image = LegacySurveyImage( + flux=torch.tensor(flux_data, dtype=torch.float32), + bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] + ) + return image + +# Example 2: HSC (5-band: g,r,i,z,y) +def load_hsc_image(flux_array): + """Load HSC 5-band image data.""" + image = HSCImage( + flux=torch.tensor(flux_array, dtype=torch.float32), + bands=['HSC-G', 'HSC-R', 'HSC-I', 'HSC-Z', 'HSC-Y'] + ) + return image -# Apply survey-specific preprocessing -preprocessor = ImagePreprocessor(survey='HSC') -processed_image = preprocessor(image) +# Note: AION-1 automatically crops/pads images to 
96x96 pixels ``` ### Working with Spectra -Load and prepare spectroscopic data: +Load and prepare spectroscopic observations: ```python -from aion.modalities import Spectrum -from astropy.io import fits - -# Load SDSS spectrum -hdul = fits.open('spec-plate-mjd-fiber.fits') -wavelength = 10**hdul[1].data['loglam'] # Convert log wavelength -flux = hdul[1].data['flux'] -ivar = hdul[1].data['ivar'] - -# Create Spectrum modality -spectrum = Spectrum( - wavelength=wavelength, - flux=flux, - ivar=ivar, - survey='SDSS' -) - -# The model handles resampling to internal wavelength grid automatically +from aion.modalities import DESISpectrum, SDSSSpectrum + +def load_desi_spectrum(flux, ivar, mask, wavelength): + """Load DESI spectrum data.""" + spectrum = DESISpectrum( + flux=torch.tensor(flux, dtype=torch.float32), + ivar=torch.tensor(ivar, dtype=torch.float32), + mask=torch.tensor(mask, dtype=torch.bool), + wavelength=torch.tensor(wavelength, dtype=torch.float32) + ) + return spectrum + +def load_sdss_spectrum_from_fits(fits_path): + """Load SDSS spectrum from FITS file.""" + with fits.open(fits_path) as hdul: + data = hdul[1].data + wavelength = 10**data['loglam'] # Convert from log wavelength + flux = data['flux'] + ivar = data['ivar'] + + # Create mask for bad pixels + mask = (ivar > 0) & (flux > 0) + + spectrum = SDSSSpectrum( + flux=torch.tensor(flux, dtype=torch.float32), + ivar=torch.tensor(ivar, dtype=torch.float32), + mask=torch.tensor(mask, dtype=torch.bool), + wavelength=torch.tensor(wavelength, dtype=torch.float32) + ) + return spectrum ``` -### Working with Catalog Data +### Working with Photometric Data -Process tabular astronomical measurements: +Prepare scalar measurements like fluxes and shape parameters: ```python from aion.modalities import ( - FluxG, FluxR, FluxI, FluxZ, - E1, E2, RadiusCARP, Redshift + LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ, + Z, GaiaParallax ) -# Load catalog data (e.g., from pandas DataFrame) 
-catalog_entry = { - 'flux_g': FluxG(value=catalog_df['flux_g'].values), - 'flux_r': FluxR(value=catalog_df['flux_r'].values), - 'e1': E1(value=catalog_df['e1'].values), - 'e2': E2(value=catalog_df['e2'].values), - 'radius': RadiusCARP(value=catalog_df['radius'].values) -} -``` +def create_photometry_modalities(catalog_data): + """Create modalities from catalog measurements.""" + modalities = [] -## Basic Inference + # Photometric fluxes + if 'flux_g' in catalog_data: + modalities.append(LegacySurveyFluxG( + value=torch.tensor(catalog_data['flux_g'], dtype=torch.float32) + )) -### Single Modality Prediction + if 'flux_r' in catalog_data: + modalities.append(LegacySurveyFluxR( + value=torch.tensor(catalog_data['flux_r'], dtype=torch.float32) + )) -Predict missing photometric measurements from available data: + # Redshift + if 'redshift' in catalog_data: + modalities.append(Z( + value=torch.tensor(catalog_data['redshift'], dtype=torch.float32) + )) -```python -# Given g,r,i bands, predict z band -inputs = { - 'flux_g': FluxG(value=[19.5]), - 'flux_r': FluxR(value=[18.2]), - 'flux_i': FluxI(value=[17.8]) -} - -# Predict z-band flux -with torch.no_grad(): - predictions = model.generate( - inputs=inputs, - targets=['flux_z'] - ) - -z_flux = predictions['flux_z'].value[0] -print(f"Predicted z-band flux: {z_flux:.2f}") + return modalities ``` -### Batch Processing - -Process multiple objects efficiently: +## Basic Workflows -```python -# Prepare batch of galaxies -batch_images = [load_galaxy(i) for i in range(32)] -batch = { - 'image': Image.batch(batch_images) -} +### Workflow 1: Embedding Extraction -# Generate properties for all galaxies -with torch.no_grad(): - results = model.generate( - inputs=batch, - targets=['redshift', 'e1', 'e2', 'radius'] - ) - -# Extract results -redshifts = results['redshift'].value -ellipticities = np.sqrt(results['e1'].value**2 + results['e2'].value**2) -``` - -## Multimodal Generation - -### Conditional Generation - -Generate multiple 
modalities conditioned on partial observations: +Extract learned representations for downstream machine learning: ```python -# Complex multimodal generation example -def analyze_galaxy(image_path, known_redshift=None): - # Load image - image = load_and_preprocess_image(image_path) - - inputs = {'image': image} - if known_redshift: - inputs['redshift'] = Redshift(value=[known_redshift]) - - # Generate comprehensive analysis - targets = [ - 'spectrum', # Full spectrum - 'flux_g', 'flux_r', 'flux_i', 'flux_z', # Photometry - 'e1', 'e2', # Shape parameters - 'radius', # Size - 'parallax', # Distance indicator - 'extinction_v' # Dust extinction - ] +def extract_galaxy_embeddings(data_list, model, codec_manager): + """Extract embeddings from a list of galaxy observations.""" + all_embeddings = [] - with torch.no_grad(): - results = model.generate( - inputs=inputs, - targets=targets, - num_generations=1, - temperature=1.0 - ) - - return results + # Process in batches for efficiency + batch_size = 32 + for i in range(0, len(data_list), batch_size): + batch = data_list[i:i + batch_size] -# Analyze a galaxy -galaxy_properties = analyze_galaxy('ngc1234.fits', known_redshift=0.05) -``` + # Encode all modalities in the batch + batch_tokens = codec_manager.encode(*batch) -### Uncertainty Quantification + # Extract embeddings + with torch.no_grad(): + embeddings = model.encode(batch_tokens, num_encoder_tokens=600) + # Pool over sequence dimension + pooled = embeddings.mean(dim=1) # [batch, 768] -Generate multiple samples to estimate uncertainties: + all_embeddings.append(pooled.cpu().numpy()) -```python -def estimate_uncertainty(inputs, target, num_samples=100): - samples = [] + return np.vstack(all_embeddings) - with torch.no_grad(): - for _ in range(num_samples): - result = model.generate( - inputs=inputs, - targets=[target], - temperature=1.2 # Higher temperature for more diversity - ) - samples.append(result[target].value[0]) - - samples = np.array(samples) - return { - 
'mean': np.mean(samples), - 'std': np.std(samples), - 'percentiles': np.percentile(samples, [16, 50, 84]) - } - -# Estimate redshift uncertainty -z_stats = estimate_uncertainty( - inputs={'image': galaxy_image}, - target='redshift' +# Usage example +galaxy_embeddings = extract_galaxy_embeddings( + [image1, image2, image3, ...], + model, + codec_manager ) -print(f"Redshift: {z_stats['mean']:.3f} ± {z_stats['std']:.3f}") ``` -## Cross-Modal Translation - -### Image to Spectrum +### Workflow 2: Redshift Prediction -Convert imaging observations to spectroscopic predictions: +Predict redshift distributions from various input modalities: ```python -def image_to_spectrum(image, wavelength_range=(3800, 9200)): - """Generate spectrum from multi-band image.""" +def predict_redshift_distribution(inputs, model, codec_manager): + """Predict redshift probability distribution.""" + # Encode inputs + tokens = codec_manager.encode(*inputs) - # Generate spectrum tokens + # Predict redshift with torch.no_grad(): - result = model.generate( - inputs={'image': image}, - targets=['spectrum'] + predictions = model( + tokens, + target_mask={"tok_z": torch.zeros(len(inputs), 1)}, + num_encoder_tokens=600 ) - spectrum = result['spectrum'] - - # Filter to desired wavelength range - mask = (spectrum.wavelength >= wavelength_range[0]) & \ - (spectrum.wavelength <= wavelength_range[1]) + # Convert to probabilities + redshift_logits = predictions["tok_z"] + redshift_probs = torch.softmax(redshift_logits, dim=-1) - return { - 'wavelength': spectrum.wavelength[mask], - 'flux': spectrum.flux[mask] - } + return redshift_probs -# Generate and plot spectrum -synthetic_spec = image_to_spectrum(galaxy_image) -plt.plot(synthetic_spec['wavelength'], synthetic_spec['flux']) -plt.xlabel('Wavelength (Å)') -plt.ylabel('Flux') -plt.title('AION-1 Generated Spectrum from Image') +# Example: Predict from photometry +redshift_dist = predict_redshift_distribution( + [flux_g, flux_r, flux_i, flux_z], + model, + 
codec_manager +) ``` -### Spectrum to Image +### Workflow 3: Reconstruction -Inverse translation - generate images from spectra: +Reconstruct modalities through the encode-decode process: ```python -def spectrum_to_image(spectrum, bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z']): - """Generate multi-band image from spectrum.""" - - with torch.no_grad(): - result = model.generate( - inputs={'spectrum': spectrum}, - targets=['image'], - target_bands=bands - ) +def reconstruct_modality(original_modality, model, codec_manager, modality_class, **metadata): + """Reconstruct a modality through encode-decode cycle.""" + # Encode original + tokens = codec_manager.encode(original_modality) + + # Decode back + reconstructed = codec_manager.decode( + tokens, + modality_class, + **metadata + ) - return result['image'] + return reconstructed -# Reconstruct galaxy appearance -reconstructed_image = spectrum_to_image(observed_spectrum) +# Example: Reconstruct image +reconstructed_image = reconstruct_modality( + original_image, + model, + codec_manager, + LegacySurveyImage, + bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] +) ``` -### Super-Resolution +## Embedding Extraction -Enhance low-resolution spectra using multimodal context: +### Basic Embedding Extraction ```python -def enhance_spectrum(low_res_spectrum, supporting_data=None): - """Enhance spectrum resolution using additional modalities.""" - - inputs = {'spectrum': low_res_spectrum} +def get_embeddings(modalities, model, codec_manager, pooling='mean'): + """Extract embeddings with different pooling strategies.""" + tokens = codec_manager.encode(*modalities) - # Add supporting data if available - if supporting_data: - inputs.update(supporting_data) - - # Generate high-resolution version with torch.no_grad(): - result = model.generate( - inputs=inputs, - targets=['spectrum_highres'], - num_generations=1 - ) + embeddings = model.encode(tokens, num_encoder_tokens=600) + + # Apply pooling + if pooling == 'mean': + return 
embeddings.mean(dim=1) + elif pooling == 'max': + return embeddings.max(dim=1)[0] + elif pooling == 'cls': + return embeddings[:, 0] # First token + else: + return embeddings # Return full sequence - return result['spectrum_highres'] - -# Example with photometric support -enhanced = enhance_spectrum( - sdss_spectrum, - supporting_data={ - 'flux_g': FluxG(value=[18.5]), - 'flux_r': FluxR(value=[17.2]) - } -) +# Usage +embeddings = get_embeddings([image, spectrum], model, codec_manager) ``` -## Representation Learning - -### Extracting Embeddings +### Multi-Modal Embeddings -Use AION-1's learned representations for downstream tasks: +Combine embeddings from different modalities: ```python -def extract_embeddings(data_dict, pool='mean'): - """Extract feature embeddings from AION-1 encoder.""" - - # Tokenize inputs - tokens = model.tokenize(data_dict) +def get_multimodal_embeddings(image, spectrum, photometry, model, codec_manager): + """Extract embeddings from multiple modality types.""" - # Get encoder representations - with torch.no_grad(): - embeddings = model.encode(tokens) - - # Pool over sequence dimension - if pool == 'mean': - features = embeddings.mean(dim=1) - elif pool == 'cls': - features = embeddings[:, 0] # First token - elif pool == 'max': - features = embeddings.max(dim=1)[0] - - return features.cpu().numpy() - -# Extract features for clustering -galaxy_features = extract_embeddings({ - 'image': galaxy_image, - 'spectrum': galaxy_spectrum -}) -``` + # Get embeddings from each modality type + image_tokens = codec_manager.encode(image) + spectrum_tokens = codec_manager.encode(spectrum) + photo_tokens = codec_manager.encode(*photometry) -### Similarity Search + embeddings = {} -Find similar objects using learned representations: + with torch.no_grad(): + # Image embeddings + img_emb = model.encode(image_tokens, num_encoder_tokens=300) + embeddings['image'] = img_emb.mean(dim=1) -```python -from sklearn.metrics.pairwise import cosine_similarity + # 
Spectrum embeddings + spec_emb = model.encode(spectrum_tokens, num_encoder_tokens=300) + embeddings['spectrum'] = spec_emb.mean(dim=1) -class GalaxySimilaritySearch: - def __init__(self, model): - self.model = model - self.database = [] - self.embeddings = [] - - def add_galaxy(self, galaxy_data, metadata=None): - """Add galaxy to search database.""" - embedding = extract_embeddings(galaxy_data) - self.embeddings.append(embedding) - self.database.append({ - 'data': galaxy_data, - 'metadata': metadata, - 'embedding': embedding - }) - - def find_similar(self, query_data, k=10): - """Find k most similar galaxies.""" - query_embedding = extract_embeddings(query_data) - - # Compute similarities - similarities = cosine_similarity( - query_embedding.reshape(1, -1), - np.vstack(self.embeddings) - )[0] - - # Get top k - indices = np.argsort(similarities)[::-1][:k] - - return [(self.database[i], similarities[i]) for i in indices] + # Combined embeddings + all_tokens = {**image_tokens, **spectrum_tokens, **photo_tokens} + combined_emb = model.encode(all_tokens, num_encoder_tokens=900) + embeddings['combined'] = combined_emb.mean(dim=1) -# Usage -searcher = GalaxySimilaritySearch(model) -# ... add galaxies to database ... 
-similar_galaxies = searcher.find_similar(query_galaxy, k=5) + return embeddings ``` -### Anomaly Detection +## Similarity Search -Identify unusual objects using reconstruction error: +Implement similarity search using AION embeddings: ```python -def detect_anomalies(galaxies, threshold_percentile=95): - """Detect anomalous galaxies using reconstruction error.""" - - reconstruction_errors = [] +from sklearn.metrics.pairwise import cosine_similarity +from sklearn.neighbors import NearestNeighbors - for galaxy in galaxies: - # Encode and decode +class AIONSimilaritySearch: + def __init__(self, model, codec_manager): + self.model = model + self.codec_manager = codec_manager + self.database_embeddings = [] + self.database_objects = [] + self.index = None + + def add_objects(self, objects): + """Add objects to the search database.""" + for obj in objects: + # Extract embedding + tokens = self.codec_manager.encode(*obj['modalities']) + with torch.no_grad(): + emb = self.model.encode(tokens, num_encoder_tokens=600) + emb = emb.mean(dim=1).cpu().numpy() + + self.database_embeddings.append(emb) + self.database_objects.append(obj) + + # Build search index + if self.database_embeddings: + embeddings_matrix = np.vstack(self.database_embeddings) + self.index = NearestNeighbors(n_neighbors=10, metric='cosine') + self.index.fit(embeddings_matrix) + + def search(self, query_modalities, k=5): + """Search for similar objects.""" + # Get query embedding + tokens = self.codec_manager.encode(*query_modalities) with torch.no_grad(): - reconstructed = model.generate( - inputs=galaxy, - targets=list(galaxy.keys()) - ) - - # Compute reconstruction error - error = 0 - for key in galaxy: - if key == 'image': - error += np.mean((galaxy[key].flux - - reconstructed[key].flux)**2) - elif hasattr(galaxy[key], 'value'): - error += np.mean((galaxy[key].value - - reconstructed[key].value)**2) - - reconstruction_errors.append(error) - - # Set threshold - threshold = 
np.percentile(reconstruction_errors, threshold_percentile) - - # Identify anomalies - anomalies = [g for g, e in zip(galaxies, reconstruction_errors) - if e > threshold] - - return anomalies, reconstruction_errors + query_emb = self.model.encode(tokens, num_encoder_tokens=600) + query_emb = query_emb.mean(dim=1).cpu().numpy() + + # Find nearest neighbors + distances, indices = self.index.kneighbors(query_emb, n_neighbors=k) + + results = [] + for i, idx in enumerate(indices[0]): + results.append({ + 'object': self.database_objects[idx], + 'similarity': 1 - distances[0][i], # Convert distance to similarity + 'rank': i + 1 + }) + + return results + +# Usage example +searcher = AIONSimilaritySearch(model, codec_manager) + +# Add objects to database +database_objects = [ + {'modalities': [image1, spectrum1], 'metadata': {'id': 'galaxy_1'}}, + {'modalities': [image2, spectrum2], 'metadata': {'id': 'galaxy_2'}}, + # ... more objects +] +searcher.add_objects(database_objects) + +# Search for similar objects +query_galaxy = [query_image, query_spectrum] +similar_objects = searcher.search(query_galaxy, k=10) + +print(f"Found {len(similar_objects)} similar objects:") +for result in similar_objects: + print(f"Rank {result['rank']}: {result['object']['metadata']['id']} " + f"(similarity: {result['similarity']:.3f})") ``` -## Advanced Applications +## Property Prediction -### Multi-Survey Integration +Use AION embeddings for various prediction tasks: -Combine observations from different surveys: +### Redshift Estimation with k-NN ```python -def integrate_multi_survey(hsc_image, sdss_spectrum, desi_spectrum=None): - """Integrate observations from multiple surveys.""" +from sklearn.neighbors import KNeighborsRegressor +from sklearn.model_selection import train_test_split +from sklearn.metrics import mean_absolute_error, r2_score - inputs = { - 'image': hsc_image, - 'spectrum_sdss': sdss_spectrum - } +def train_redshift_predictor(galaxies_with_redshifts, model, codec_manager): + 
"""Train a k-NN regressor for redshift prediction.""" - if desi_spectrum: - inputs['spectrum_desi'] = desi_spectrum + # Extract embeddings and targets + embeddings = [] + redshifts = [] - # Generate unified representation - with torch.no_grad(): - # Extract all available properties - results = model.generate( - inputs=inputs, - targets=['redshift', 'stellar_mass', 'sfr', 'metallicity'] - ) + for galaxy in galaxies_with_redshifts: + tokens = codec_manager.encode(*galaxy['modalities']) + with torch.no_grad(): + emb = model.encode(tokens, num_encoder_tokens=600) + emb = emb.mean(dim=1).cpu().numpy() - # Generate missing modalities - if not desi_spectrum: - results['spectrum_desi'] = model.generate( - inputs=inputs, - targets=['spectrum_desi'] - )['spectrum_desi'] + embeddings.append(emb[0]) # Remove batch dimension + redshifts.append(galaxy['redshift']) - return results -``` + X = np.array(embeddings) + y = np.array(redshifts) -### Time Series Analysis + # Split data + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) -Analyze variable objects across epochs: + # Train k-NN regressor + knn = KNeighborsRegressor(n_neighbors=5) + knn.fit(X_train, y_train) -```python -def analyze_variable_object(observations): - """ - Analyze time-variable astronomical object. 
+ # Evaluate + y_pred = knn.predict(X_test) + mae = mean_absolute_error(y_test, y_pred) + r2 = r2_score(y_test, y_pred) - observations: list of (time, data_dict) tuples - """ + print(f"Redshift prediction - MAE: {mae:.4f}, R²: {r2:.4f}") - embeddings_over_time = [] - properties_over_time = [] + return knn - for time, data in observations: - # Extract embeddings - embedding = extract_embeddings(data) - embeddings_over_time.append(embedding) +def predict_redshift(new_galaxy, trained_model, model, codec_manager): + """Predict redshift for a new galaxy.""" + tokens = codec_manager.encode(*new_galaxy) + with torch.no_grad(): + emb = model.encode(tokens, num_encoder_tokens=600) + emb = emb.mean(dim=1).cpu().numpy() - # Predict properties - with torch.no_grad(): - props = model.generate( - inputs=data, - targets=['flux_g', 'flux_r', 'temperature'] - ) - - properties_over_time.append({ - 'time': time, - 'properties': props, - 'embedding': embedding - }) - - # Analyze evolution - embeddings = np.vstack(embeddings_over_time) - - # Detect significant changes - embedding_distances = np.sqrt(np.sum(np.diff(embeddings, axis=0)**2, axis=1)) - change_points = np.where(embedding_distances > np.std(embedding_distances) * 2)[0] - - return { - 'properties': properties_over_time, - 'change_points': change_points, - 'embedding_evolution': embeddings - } + predicted_z = trained_model.predict(emb)[0] + return predicted_z ``` -### Physical Parameter Estimation - -Estimate astrophysical parameters with uncertainty: +### Stellar Mass Prediction ```python -class PhysicalParameterEstimator: - def __init__(self, model, num_samples=100): - self.model = model - self.num_samples = num_samples - - def estimate_parameters(self, observations): - """Estimate physical parameters with uncertainties.""" +from sklearn.ensemble import RandomForestRegressor - # Parameters to estimate - parameters = [ - 'redshift', 'stellar_mass', 'sfr', - 'metallicity', 'age', 'extinction_v' - ] +def 
train_stellar_mass_predictor(galaxies_with_masses, model, codec_manager): + """Train predictor for stellar mass estimation.""" - # Generate multiple samples - samples = {param: [] for param in parameters} + # Similar to redshift prediction but for stellar mass + embeddings = [] + masses = [] + for galaxy in galaxies_with_masses: + tokens = codec_manager.encode(*galaxy['modalities']) with torch.no_grad(): - for _ in range(self.num_samples): - results = self.model.generate( - inputs=observations, - targets=parameters, - temperature=1.1 - ) - - for param in parameters: - if param in results: - samples[param].append(results[param].value[0]) - - # Compute statistics - estimates = {} - for param, values in samples.items(): - if values: - values = np.array(values) - estimates[param] = { - 'median': np.median(values), - 'mean': np.mean(values), - 'std': np.std(values), - 'ci_68': np.percentile(values, [16, 84]), - 'ci_95': np.percentile(values, [2.5, 97.5]) - } - - return estimates + emb = model.encode(tokens, num_encoder_tokens=600) + emb = emb.mean(dim=1).cpu().numpy() -# Usage -estimator = PhysicalParameterEstimator(model) -parameters = estimator.estimate_parameters({ - 'image': galaxy_image, - 'spectrum': galaxy_spectrum -}) - -print(f"Stellar Mass: {parameters['stellar_mass']['median']:.2e} " - f"+/- {parameters['stellar_mass']['std']:.2e} M_sun") -``` + embeddings.append(emb[0]) + masses.append(np.log10(galaxy['stellar_mass'])) # Log stellar mass -## Performance Optimization - -### Efficient Batch Processing - -```python -from torch.utils.data import DataLoader, Dataset + X = np.array(embeddings) + y = np.array(masses) -class AIONDataset(Dataset): - def __init__(self, data_list): - self.data = data_list + # Train Random Forest + rf = RandomForestRegressor(n_estimators=100, random_state=42) + rf.fit(X, y) - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - return self.data[idx] - -def process_large_dataset(data_list, batch_size=32): - 
"""Efficiently process large datasets.""" - - dataset = AIONDataset(data_list) - dataloader = DataLoader(dataset, batch_size=batch_size, - num_workers=4, pin_memory=True) + return rf +``` - all_results = [] +## Performance Tips - with torch.no_grad(): - for batch in dataloader: - # Process batch - results = model.generate( - inputs=batch, - targets=['redshift', 'stellar_mass'] - ) - all_results.append(results) - - # Concatenate results - return {k: np.concatenate([r[k].value for r in all_results]) - for k in all_results[0]} -``` +### Batch Processing -### Memory-Efficient Processing +Process multiple objects efficiently: ```python -def process_with_chunking(large_spectrum, chunk_size=1000): - """Process very long spectra in chunks.""" +def process_batch_efficiently(object_list, model, codec_manager, batch_size=32): + """Process objects in batches for better GPU utilization.""" + results = [] - n_chunks = len(large_spectrum.wavelength) // chunk_size + 1 - chunk_results = [] + for i in range(0, len(object_list), batch_size): + batch = object_list[i:i + batch_size] - for i in range(n_chunks): - start = i * chunk_size - end = min((i + 1) * chunk_size, len(large_spectrum.wavelength)) + # Group by modality type for efficient encoding + images = [obj for obj in batch if 'image' in obj] + spectra = [obj for obj in batch if 'spectrum' in obj] - chunk = Spectrum( - wavelength=large_spectrum.wavelength[start:end], - flux=large_spectrum.flux[start:end] - ) + batch_results = [] with torch.no_grad(): - result = model.process_spectrum_chunk(chunk) - chunk_results.append(result) + # Process images + if images: + image_batch = [obj['image'] for obj in images] + tokens = codec_manager.encode(*image_batch) + embeddings = model.encode(tokens, num_encoder_tokens=600) + batch_results.extend(embeddings.mean(dim=1).cpu().numpy()) + + # Process spectra + if spectra: + spectrum_batch = [obj['spectrum'] for obj in spectra] + tokens = codec_manager.encode(*spectrum_batch) + embeddings = 
model.encode(tokens, num_encoder_tokens=300) + batch_results.extend(embeddings.mean(dim=1).cpu().numpy()) + + results.extend(batch_results) - # Combine chunks - return combine_spectrum_chunks(chunk_results) + return results ``` -### GPU Memory Management +### Memory Management -```python -import gc +Handle large datasets with limited GPU memory: -def memory_efficient_generation(inputs, targets, max_batch=16): - """Generate with automatic batch size adjustment.""" +```python +def process_large_dataset(dataset, model, codec_manager, max_batch_size=16): + """Process large datasets with automatic memory management.""" + import gc - batch_size = max_batch + current_batch_size = max_batch_size + results = [] - while batch_size > 0: + i = 0 + while i < len(dataset): try: + batch = dataset[i:i + current_batch_size] + + # Process batch + batch_tokens = codec_manager.encode(*batch) with torch.no_grad(): - results = model.generate( - inputs=inputs, - targets=targets, - batch_size=batch_size - ) - return results + embeddings = model.encode(batch_tokens, num_encoder_tokens=600) + results.append(embeddings.mean(dim=1).cpu()) + + i += current_batch_size except torch.cuda.OutOfMemoryError: - # Clear cache and try smaller batch + # Clear memory and reduce batch size torch.cuda.empty_cache() gc.collect() - batch_size //= 2 + current_batch_size = max(1, current_batch_size // 2) + print(f"Reduced batch size to {current_batch_size}") - if batch_size == 0: - raise RuntimeError("Cannot fit even batch size 1") + if current_batch_size == 0: + raise RuntimeError("Cannot process even single example") - raise RuntimeError("Failed to process") + return torch.cat(results, dim=0) ``` -## Best Practices +### Using Mixed Precision -### 1. Data Preparation -- Always normalize and preprocess data according to survey specifications -- Provide inverse variance when available for optimal results -- Use appropriate data types for each modality +Speed up inference with automatic mixed precision: -### 2. 
Model Selection -- Use `aion-tiny` for quick experiments and limited GPU memory -- Use `aion-base` for most research applications -- Use `aion-large` for highest accuracy when computational resources permit +```python +def extract_embeddings_amp(modalities, model, codec_manager): + """Extract embeddings using automatic mixed precision.""" + from torch import autocast -### 3. Generation Settings -- Lower temperature (0.8-1.0) for more deterministic outputs -- Higher temperature (1.1-1.5) for diversity and uncertainty estimation -- Multiple generations for robust uncertainty quantification + tokens = codec_manager.encode(*modalities) -### 4. Error Handling -```python -def safe_generate(model, inputs, targets, fallback=None): - """Safely generate with error handling.""" - try: - return model.generate(inputs=inputs, targets=targets) - except Exception as e: - print(f"Generation failed: {e}") - return fallback or {t: None for t in targets} + with torch.no_grad(): + with autocast(device_type="cuda"): + embeddings = model.encode(tokens, num_encoder_tokens=600) + + return embeddings.float() # Convert back to float32 ``` -## Conclusion +## Best Practices + +1. **Always use `.eval()` mode** for inference to disable dropout and batch norm updates +2. **Use `torch.no_grad()`** to disable gradient computation and save memory +3. **Process in batches** when possible for better GPU utilization +4. **Pool embeddings appropriately** - mean pooling works well for most tasks +5. **Use consistent device placement** - ensure all tensors are on the same device +6. **Clear GPU cache** periodically when processing large datasets + +## Troubleshooting + +### Common Issues -AION-1 provides a powerful and flexible framework for multimodal astronomical analysis. Its ability to seamlessly integrate diverse observations enables new research possibilities: +1. **CUDA out of memory**: Reduce the batch size, or use gradient checkpointing when fine-tuning +2. 
**Slow processing**: Ensure the data is on the GPU and use batch processing +3. **Shape mismatches**: Check that tensor dimensions match the expected format +4. **Device errors**: Ensure the model, data, and `codec_manager` are on the same device -- Cross-modal prediction and generation -- Unified analysis across multiple surveys -- Robust uncertainty quantification -- Discovery of unusual objects -- Efficient processing of large datasets +### Debug Mode + +```python +def debug_tokens(tokens, codec_manager): + """Debug token shapes and contents.""" + print("Token summary:") + for key, tensor in tokens.items(): + print(f" {key}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}") + print(f" range: [{tensor.min().item():.2f}, {tensor.max().item():.2f}]") +``` -For more examples and the latest updates, visit the [AION GitHub repository](https://github.com/polymathic-ai/aion) and join our community discussions. +For more advanced examples and the latest updates, see the [Tutorial Notebook](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb).
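The batch-size backoff described under Memory Management can be exercised without a GPU. The following is a minimal sketch and not AION's implementation: `process_with_backoff`, `process_batch`, and `oom_error` are hypothetical names introduced here for illustration, and the built-in `MemoryError` stands in for `torch.cuda.OutOfMemoryError`. It shows one way to halve the batch size after a memory error, retry from the same position, and stop with an error once a batch of one item still fails.

```python
from typing import Callable, List, Sequence, Type, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def process_with_backoff(
    items: Sequence[T],
    process_batch: Callable[[Sequence[T]], List[R]],
    max_batch_size: int = 16,
    oom_error: Type[BaseException] = MemoryError,  # assumption: stand-in for a CUDA OOM error
) -> List[R]:
    """Process items in batches, halving the batch size after each memory error."""
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be at least 1")

    batch_size = max_batch_size
    results: List[R] = []
    i = 0
    while i < len(items):
        batch = items[i:i + batch_size]
        try:
            results.extend(process_batch(batch))
        except oom_error as err:
            if batch_size == 1:
                # Nothing smaller to try, so stop instead of retrying forever.
                raise RuntimeError("Cannot process even a single example") from err
            batch_size //= 2
            continue  # retry the same position with the smaller batch
        # Advance by the number of items actually processed.
        i += len(batch)
    return results


if __name__ == "__main__":
    def fake_encode(batch):
        # Hypothetical workload that "runs out of memory" above 4 items.
        if len(batch) > 4:
            raise MemoryError("simulated out-of-memory")
        return [x * 2 for x in batch]

    print(process_with_backoff(list(range(10)), fake_encode))
```

With PyTorch, `oom_error` would presumably be `torch.cuda.OutOfMemoryError`, with a `torch.cuda.empty_cache()` call before the retry. Checking for a batch size of one before halving is what keeps the loop from retrying the same failing batch indefinitely.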