Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

## 🎯 Overview

<div align="center">
<img src="assets/aion.png" alt="AION Logo" width="600">
</div>

AION-1 is a cutting-edge large omnimodal model specifically designed for astronomical surveys. It seamlessly integrates multiple data modalities, and enables simple adaptation to a wide range of astronomical tasks.


Expand Down Expand Up @@ -142,7 +146,7 @@ flux_g = LegacySurveyFluxG(value=torch.tensor([flux_values]))

---

## 💡 Key Use Cases
## 💡 Example Use Cases

### 🔍 Similarity Search
Find galaxies similar to a query object across different modalities:
Expand Down
38 changes: 38 additions & 0 deletions aion/modalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ class HSCImage(Image):
"""HSC image modality data."""

token_key: ClassVar[str] = "tok_image_hsc"
num_tokens: ClassVar[int] = 576


class LegacySurveyImage(Image):
"""Legacy Survey image modality data."""

token_key: ClassVar[str] = "tok_image"
num_tokens: ClassVar[int] = 576


@dataclass
Expand Down Expand Up @@ -109,12 +111,14 @@ class DESISpectrum(Spectrum):
"""DESI spectrum modality data."""

token_key: ClassVar[str] = "tok_spectrum_desi"
num_tokens: ClassVar[int] = 273


class SDSSSpectrum(Spectrum):
"""SDSS spectrum modality data."""

token_key: ClassVar[str] = "tok_spectrum_sdss"
num_tokens: ClassVar[int] = 273


# Catalog modality
Expand Down Expand Up @@ -170,55 +174,63 @@ class LegacySurveyFluxG(Scalar):

name: ClassVar[str] = "FLUX_G"
token_key: ClassVar[str] = "tok_flux_g"
num_tokens: ClassVar[int] = 1


class LegacySurveyFluxR(Scalar):
"""R-band flux measurement."""

name: ClassVar[str] = "FLUX_R"
token_key: ClassVar[str] = "tok_flux_r"
num_tokens: ClassVar[int] = 1


class LegacySurveyFluxI(Scalar):
"""I-band flux measurement."""

name: ClassVar[str] = "FLUX_I"
token_key: ClassVar[str] = "tok_flux_i"
num_tokens: ClassVar[int] = 1


class LegacySurveyFluxZ(Scalar):
"""Z-band flux measurement."""

name: ClassVar[str] = "FLUX_Z"
token_key: ClassVar[str] = "tok_flux_z"
num_tokens: ClassVar[int] = 1


class LegacySurveyFluxW1(Scalar):
"""WISE W1-band flux measurement."""

name: ClassVar[str] = "FLUX_W1"
token_key: ClassVar[str] = "tok_flux_w1"
num_tokens: ClassVar[int] = 1


class LegacySurveyFluxW2(Scalar):
"""WISE W2-band flux measurement."""

name: ClassVar[str] = "FLUX_W2"
token_key: ClassVar[str] = "tok_flux_w2"
num_tokens: ClassVar[int] = 1


class LegacySurveyFluxW3(Scalar):
"""WISE W3-band flux measurement."""

name: ClassVar[str] = "FLUX_W3"
token_key: ClassVar[str] = "tok_flux_w3"
num_tokens: ClassVar[int] = 1


class LegacySurveyFluxW4(Scalar):
"""WISE W4-band flux measurement."""

name: ClassVar[str] = "FLUX_W4"
token_key: ClassVar[str] = "tok_flux_w4"
num_tokens: ClassVar[int] = 1


# Shape parameters
Expand All @@ -227,20 +239,23 @@ class LegacySurveyShapeR(Scalar):

name: ClassVar[str] = "SHAPE_R"
token_key: ClassVar[str] = "tok_shape_r"
num_tokens: ClassVar[int] = 1


class LegacySurveyShapeE1(Scalar):
"""First ellipticity component."""

name: ClassVar[str] = "SHAPE_E1"
token_key: ClassVar[str] = "tok_shape_e1"
num_tokens: ClassVar[int] = 1


class LegacySurveyShapeE2(Scalar):
"""Second ellipticity component."""

name: ClassVar[str] = "SHAPE_E2"
token_key: ClassVar[str] = "tok_shape_e2"
num_tokens: ClassVar[int] = 1


# Other scalar properties
Expand All @@ -249,6 +264,7 @@ class LegacySurveyEBV(Scalar):

name: ClassVar[str] = "EBV"
token_key: ClassVar[str] = "tok_ebv"
num_tokens: ClassVar[int] = 1


# Spectroscopic redshift
Expand All @@ -257,6 +273,7 @@ class Z(Scalar):

name: ClassVar[str] = "Z"
token_key: ClassVar[str] = "tok_z"
num_tokens: ClassVar[int] = 1


# Extinction values from HSC
Expand All @@ -265,90 +282,103 @@ class HSCAG(Scalar):

name: ClassVar[str] = "a_g"
token_key: ClassVar[str] = "tok_a_g"
num_tokens: ClassVar[int] = 1


class HSCAR(Scalar):
"""HSC a_r extinction."""

name: ClassVar[str] = "a_r"
token_key: ClassVar[str] = "tok_a_r"
num_tokens: ClassVar[int] = 1


class HSCAI(Scalar):
"""HSC a_i extinction."""

name: ClassVar[str] = "a_i"
token_key: ClassVar[str] = "tok_a_i"
num_tokens: ClassVar[int] = 1


class HSCAZ(Scalar):
"""HSC a_z extinction."""

name: ClassVar[str] = "a_z"
token_key: ClassVar[str] = "tok_a_z"
num_tokens: ClassVar[int] = 1


class HSCAY(Scalar):
"""HSC a_y extinction."""

name: ClassVar[str] = "a_y"
token_key: ClassVar[str] = "tok_a_y"
num_tokens: ClassVar[int] = 1


class HSCMagG(Scalar):
"""HSC g-band cmodel magnitude."""

name: ClassVar[str] = "g_cmodel_mag"
token_key: ClassVar[str] = "tok_mag_g"
num_tokens: ClassVar[int] = 1


class HSCMagR(Scalar):
"""HSC r-band cmodel magnitude."""

name: ClassVar[str] = "r_cmodel_mag"
token_key: ClassVar[str] = "tok_mag_r"
num_tokens: ClassVar[int] = 1


class HSCMagI(Scalar):
"""HSC i-band cmodel magnitude."""

name: ClassVar[str] = "i_cmodel_mag"
token_key: ClassVar[str] = "tok_mag_i"
num_tokens: ClassVar[int] = 1


class HSCMagZ(Scalar):
"""HSC z-band cmodel magnitude."""

name: ClassVar[str] = "z_cmodel_mag"
token_key: ClassVar[str] = "tok_mag_z"
num_tokens: ClassVar[int] = 1


class HSCMagY(Scalar):
"""HSC y-band cmodel magnitude."""

name: ClassVar[str] = "y_cmodel_mag"
token_key: ClassVar[str] = "tok_mag_y"
num_tokens: ClassVar[int] = 1


class HSCShape11(Scalar):
"""HSC i-band SDSS shape 11 component."""

name: ClassVar[str] = "i_sdssshape_shape11"
token_key: ClassVar[str] = "tok_shape11"
num_tokens: ClassVar[int] = 1


class HSCShape22(Scalar):
"""HSC i-band SDSS shape 22 component."""

name: ClassVar[str] = "i_sdssshape_shape22"
token_key: ClassVar[str] = "tok_shape22"
num_tokens: ClassVar[int] = 1


class HSCShape12(Scalar):
"""HSC i-band SDSS shape 12 component."""

name: ClassVar[str] = "i_sdssshape_shape12"
token_key: ClassVar[str] = "tok_shape12"
num_tokens: ClassVar[int] = 1


# Gaia modalities
Expand All @@ -357,55 +387,63 @@ class GaiaFluxG(Scalar):

name: ClassVar[str] = "phot_g_mean_flux"
token_key: ClassVar[str] = "tok_flux_g_gaia"
num_tokens: ClassVar[int] = 1


class GaiaFluxBp(Scalar):
"""Gaia BP-band mean flux."""

name: ClassVar[str] = "phot_bp_mean_flux"
token_key: ClassVar[str] = "tok_flux_bp_gaia"
num_tokens: ClassVar[int] = 1


class GaiaFluxRp(Scalar):
"""Gaia RP-band mean flux."""

name: ClassVar[str] = "phot_rp_mean_flux"
token_key: ClassVar[str] = "tok_flux_rp_gaia"
num_tokens: ClassVar[int] = 1


class GaiaParallax(Scalar):
"""Gaia parallax measurement."""

name: ClassVar[str] = "parallax"
token_key: ClassVar[str] = "tok_parallax"
num_tokens: ClassVar[int] = 1


class Ra(Scalar):
"""Right ascension coordinate."""

name: ClassVar[str] = "ra"
token_key: ClassVar[str] = "tok_ra"
num_tokens: ClassVar[int] = 1


class Dec(Scalar):
"""Declination coordinate."""

name: ClassVar[str] = "dec"
token_key: ClassVar[str] = "tok_dec"
num_tokens: ClassVar[int] = 1


class GaiaXpBp(Scalar):
"""Gaia BP spectral coefficients."""

name: ClassVar[str] = "bp_coefficients"
token_key: ClassVar[str] = "tok_xp_bp"
num_tokens: ClassVar[int] = 55


class GaiaXpRp(Scalar):
"""Gaia RP spectral coefficients."""

name: ClassVar[str] = "rp_coefficients"
token_key: ClassVar[str] = "tok_xp_rp"
num_tokens: ClassVar[int] = 55


ScalarModalities = {
Expand Down
57 changes: 50 additions & 7 deletions aion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,66 @@ def encode(
def forward(
self,
input_dict: Dict[str, torch.Tensor],
target_mask: Dict[str, torch.Tensor],
target_modality: list[object],
input_mask: Optional[Dict[str, torch.Tensor]] = None,
num_decoder_tokens: int = 256,
num_encoder_tokens: int = 256,
) -> torch.Tensor:
"""
The forward function returns the logits of the requested target outputs, given the input data.
Helpful function to compute the logits of the requested target outputs, given the input data.

Args:
input_dict (Dict[str, torch.Tensor]): Input data dictionary.
target_mask (Dict[str, torch.Tensor]): Target mask dictionary, defines which modalities to predict, and which tokens within that modality.
input_mask (Dict[str, torch.Tensor], optional): Input mask dictionary. Defaults to None.
num_encoder_tokens (int, optional): Maximum number of encoder tokens. Defaults to 256.
target_modality (list[object]): List of target modalities to be predicted.
input_mask (Dict[str, torch.Tensor], optional): Mask dictionary. Defaults to None.

Returns:
torch.Tensor: Output tensor of the model.
"""
# Get batch size:
B = list(input_dict.values())[0].shape[0]

# Dynamically compute the number of encoder tokens
num_encoder_tokens = 0
for mod in input_dict.keys():
num_encoder_tokens += (
input_dict[mod].shape[1] if input_dict[mod].dim() == 2 else 1
)

# Dynamically build the target mask and decoder tokens
target_mask = {}
num_decoder_tokens = 0
target_modality = (
[target_modality]
if not isinstance(target_modality, list)
else target_modality
)
for mod in target_modality:
target_mask[mod.token_key] = torch.zeros(B, mod.num_tokens).to(torch.bool)
num_decoder_tokens += mod.num_tokens

logit_dict = self._forward(
input_dict,
target_mask=target_mask,
input_mask=input_mask,
num_decoder_tokens=num_decoder_tokens,
num_encoder_tokens=num_encoder_tokens,
)

for mod in logit_dict.keys():
logit_dict[mod] = logit_dict[mod].view(B, target_mask[mod].shape[1], -1)

return logit_dict

def _forward(
self,
input_dict: Dict[str, torch.Tensor],
target_mask: Dict[str, torch.Tensor],
input_mask: Optional[Dict[str, torch.Tensor]] = None,
num_decoder_tokens: int = 256,
num_encoder_tokens: int = 256,
) -> torch.Tensor:
"""
The forward function returns the logits of the requested target outputs, given the input data.
"""
# Embedding inputs and targets
encoder_tokens, encoder_emb, encoder_mask, _ = self.embed_inputs(
input_dict, mask=input_mask, num_encoder_tokens=num_encoder_tokens
Expand Down
Binary file added assets/aion.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading