-
Notifications
You must be signed in to change notification settings - Fork 0
WIP: Prototype of minimal codebase #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -206,4 +206,6 @@ marimo/_static/ | |
| marimo/_lsp/ | ||
| __marimo__/ | ||
|
|
||
| shine/_version.py | ||
| shine/_version.py | ||
|
|
||
| results/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
image:
  pixel_scale: 0.2
  size_x: 32
  size_y: 32
  n_objects: 1
  noise:
    type: Gaussian
    sigma: 0.01

psf:
  type: Gaussian
  sigma: 0.1

gal:
  type: Exponential
  flux:
    type: LogNormal
    mean: 100.0
    sigma: 0.1
  half_light_radius:
    type: Uniform
    min: 0.3
    max: 0.8
  shear:
    type: G1G2
    g1:
      type: Normal
      mean: 0.05
      sigma: 0.02
    g2:
      type: Normal
      mean: -0.05
      sigma: 0.02

inference:
  warmup: 100
  samples: 100
  chains: 1
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| from typing import Union, Optional, Dict, Any, List | ||
| from pydantic import BaseModel, Field, validator | ||
| import yaml | ||
| from pathlib import Path | ||
|
|
||
| # --- Distribution Models (for Priors) --- | ||
|
|
||
class DistributionConfig(BaseModel):
    """Description of a probability distribution used as a prior.

    ``type`` names the distribution. The remaining declared fields are
    optional parameters that default to ``None``; this model does not check
    which of them are meaningful for a given ``type``.
    """

    type: str
    # Location / scale parameters.
    mean: Optional[float] = None
    sigma: Optional[float] = None
    # Lower / upper bounds.
    min: Optional[float] = None
    max: Optional[float] = None

    class Config:
        # Keep undeclared keys so other distributions can carry extra parameters.
        extra = "allow"
|
|
||
| # --- Component Models --- | ||
|
|
||
class NoiseConfig(BaseModel):
    """Noise settings for an image: a noise ``type`` and its ``sigma``."""

    # Name of the noise model; "Gaussian" unless the config says otherwise.
    type: str = "Gaussian"
    # Required noise level.
    sigma: float
|
|
||
class ImageConfig(BaseModel):
    """Geometry and noise settings of the image.

    The numeric fields are bounded at validation time so that an impossible
    geometry (zero/negative pixel scale or image size, fewer than one object)
    is reported when the config is loaded instead of failing later while the
    image is being drawn.
    """

    # Pixel scale; must be strictly positive.
    pixel_scale: float = Field(..., gt=0)
    # Image dimensions in pixels; must be strictly positive.
    size_x: int = Field(..., gt=0)
    size_y: int = Field(..., gt=0)
    # Number of objects; defaults to 1 for simple tests.
    n_objects: int = Field(1, ge=1)
    noise: NoiseConfig
|
|
||
class PSFConfig(BaseModel):
    """Point-spread-function settings: profile ``type`` and its parameters."""

    # Profile name; "Gaussian" by default.
    type: str = "Gaussian"
    # Required width parameter.
    sigma: float
    # Only relevant for a Moffat profile.
    beta: Optional[float] = 2.5
|
|
||
class ShearComponentConfig(BaseModel):
    """A single shear component given either as a fixed value or a distribution."""

    # Distribution name. When None, a fixed value is assumed to be supplied
    # by the parent or handled elsewhere.
    type: Optional[str] = None
    mean: Optional[float] = 0.0
    sigma: Optional[float] = 0.05

    # A bare float in the YAML would need a custom validator; for now the
    # input is assumed to be structured as described in the design doc.
|
|
||
class ShearConfig(BaseModel):
    """Shear parametrisation with two components, ``g1`` and ``g2``."""

    # Parametrisation name; "G1G2" by default.
    type: str = "G1G2"
    # Each component is either a plain number or a distribution description.
    g1: Union[float, DistributionConfig]
    g2: Union[float, DistributionConfig]
|
|
||
class GalaxyConfig(BaseModel):
    """Galaxy profile settings: profile ``type``, its parameters and the shear.

    ``flux``, ``half_light_radius`` and ``n`` may each be a plain number or a
    distribution description.
    """

    # Profile name; the default was changed from Sersic to Exponential.
    type: str = "Exponential"
    # Optional so that an Exponential profile does not have to provide it.
    n: Optional[Union[float, DistributionConfig]] = None
    flux: Union[float, DistributionConfig]
    # Required field. The previous ``Field(..., alias="half_light_radius")``
    # aliased the field to its own name, which had no effect.
    # NOTE(review): a reviewer asked "What does" about that declaration (the
    # rest of the question is cut off in the review text).
    half_light_radius: Union[float, DistributionConfig]
    shear: ShearConfig
|
|
||
class InferenceConfig(BaseModel):
    """Sampler settings.

    Counts are bounded at validation time so that a negative warmup or a
    zero/negative number of samples or chains is rejected when the config is
    loaded rather than when the sampler is started.
    """

    # Number of warmup steps; zero is allowed.
    warmup: int = Field(500, ge=0)
    # Number of samples to draw; at least one.
    samples: int = Field(1000, ge=1)
    # Number of chains; at least one.
    chains: int = Field(1, ge=1)
    dense_mass: bool = False
|
|
||
class ShineConfig(BaseModel):
    """Top-level configuration combining image, PSF, galaxy and inference settings."""

    image: ImageConfig
    psf: PSFConfig
    gal: GalaxyConfig
    # Falls back to an InferenceConfig built from its own defaults.
    inference: InferenceConfig = Field(default_factory=InferenceConfig)
    # Optional location of input data; None when not provided.
    data_path: Optional[str] = None
    # Where outputs go; "results" by default.
    output_path: str = "results"
|
|
||
class ConfigHandler:
    """Load a ``ShineConfig`` from a YAML file."""

    @staticmethod
    def load(path: str) -> ShineConfig:
        """Read ``path`` as YAML and validate it into a ``ShineConfig``.

        Args:
            path: Location of the YAML configuration file.

        Returns:
            The configuration built from the file's top-level mapping.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If the file is empty or its top level is not a mapping.
        """
        config_path = Path(path)
        if not config_path.exists():
            raise FileNotFoundError(f"Config file not found: {config_path}")

        with config_path.open("r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # ``safe_load`` yields None for an empty document and may yield a list
        # or scalar; unpacking those with ``**`` would fail with a TypeError
        # that says nothing about the config file.
        if not isinstance(data, dict):
            raise ValueError(
                f"Config file {config_path} must contain a top-level mapping, "
                f"got {type(data).__name__}"
            )

        # Validation and type conversion are done by Pydantic.
        return ShineConfig(**data)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| import jax.numpy as jnp | ||
| import jax | ||
| from dataclasses import dataclass | ||
| from typing import Optional, Any, Dict | ||
| from shine.config import ShineConfig | ||
|
|
||
@dataclass
class Observation:
    """Bundle of an image, its noise map and the PSF description."""

    # Pixel data of the observation.
    image: jnp.ndarray
    # Array with the same role per pixel as ``image``; holds the noise level.
    noise_map: jnp.ndarray
    # PSF settings kept as a plain dict rather than as a PSF object.
    # NOTE(review): a reviewer asked "Why not using the Class" here (the class
    # name is cut off in the review text).
    psf_config: Dict[str, Any]
    # Optional WCS information; None when not provided.
    wcs: Any = None
|
|
||
class DataLoader:
    """Produce an ``Observation`` for a configuration.

    Only synthetic generation is implemented; loading real data raises
    ``NotImplementedError``.
    """

    @staticmethod
    def load(config: ShineConfig) -> Observation:
        """Return an observation for ``config``.

        Raises:
            NotImplementedError: If ``config.data_path`` is set (and is not the
                literal string ``"None"``), because real data loading is not
                implemented yet.
        """
        if config.data_path and config.data_path != "None":
            # TODO: Implement real data loading (Fits/HDF5)
            raise NotImplementedError("Real data loading not yet implemented. Use synthetic generation.")

        print("No data path provided. Generating synthetic data...")
        return DataLoader.generate_synthetic(config)

    @staticmethod
    def _point_estimate(param):
        """Collapse a fixed number or a distribution config to a single "truth" value.

        A plain number is returned as a float. For a distribution, ``mean`` is
        used when set; a ``Uniform`` with both bounds set yields the midpoint.

        Raises:
            ValueError: If no value can be derived. Previously ``None`` was
                returned here and only failed later inside GalSim.
        """
        if isinstance(param, (float, int)):
            return float(param)
        if param.mean is not None:
            return param.mean
        if param.type == "Uniform" and param.min is not None and param.max is not None:
            return (param.min + param.max) / 2.0
        raise ValueError(
            f"Cannot derive a truth value from distribution of type {param.type!r}: "
            "set 'mean', or 'min' and 'max' for a Uniform distribution"
        )

    @staticmethod
    def generate_synthetic(config: ShineConfig) -> Observation:
        """Simulate a single noisy galaxy image with GalSim.

        The galaxy parameters are the point estimates of the configured values
        (see ``_point_estimate``). The noise generator is seeded with 0, so the
        result is the same on every call for a given config.

        Raises:
            NotImplementedError: If the PSF type is not ``"Gaussian"`` or the
                galaxy type is not ``"Exponential"``.
            ValueError: If a galaxy parameter has no usable point estimate.
        """
        import galsim

        # 1. Define PSF
        # NOTE(review): "Eventually this will be a single line: and all the
        # code should be written in the psf_utils sub package/module."
        # "The same applies to the following lines."
        if config.psf.type == "Gaussian":
            psf = galsim.Gaussian(sigma=config.psf.sigma)
        else:
            raise NotImplementedError(f"PSF type {config.psf.type} not supported for synthetic gen")

        # 2. Define Galaxy (using mean values from config for "truth")
        gal_flux = DataLoader._point_estimate(config.gal.flux)
        gal_hlr = DataLoader._point_estimate(config.gal.half_light_radius)

        # Shear
        g1 = DataLoader._point_estimate(config.gal.shear.g1)
        g2 = DataLoader._point_estimate(config.gal.shear.g2)
        shear = galsim.Shear(g1=g1, g2=g2)
        # NOTE(review): "Missing intrinsic ellipticity."

        # Create Galaxy Object - Use Exponential (Sersic n=1).
        # The galaxy type is checked the same way as the PSF type (raised in
        # review: "if we follow the previous logic it should check the galaxy
        # type"), so an unsupported type is reported instead of silently
        # producing an Exponential profile.
        if config.gal.type == "Exponential":
            gal = galsim.Exponential(half_light_radius=gal_hlr, flux=gal_flux)
        else:
            raise NotImplementedError(f"Galaxy type {config.gal.type} not supported for synthetic gen")
        gal = gal.shear(shear)

        # Convolve
        final = galsim.Convolve([gal, psf])

        # 3. Draw Image
        image = final.drawImage(
            nx=config.image.size_x,
            ny=config.image.size_y,
            scale=config.image.pixel_scale,
        ).array

        # 4. Add Noise
        rng = galsim.BaseDeviate(0)
        noise_sigma = config.image.noise.sigma
        noise = galsim.GaussianNoise(rng, sigma=noise_sigma)

        # GalSim image for noise addition
        gs_image = galsim.Image(image)
        gs_image.addNoise(noise)
        noisy_image = gs_image.array

        # Constant map holding sigma**2 for every pixel.
        noise_map = jnp.ones_like(noisy_image) * (noise_sigma**2)

        # Return JAX arrays and PSF config
        return Observation(
            image=jnp.array(noisy_image),
            noise_map=noise_map,
            psf_config={'type': config.psf.type, 'sigma': config.psf.sigma},
        )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| import jax | ||
| import numpyro | ||
| from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO | ||
| from numpyro.infer.autoguide import AutoDelta | ||
| from typing import Dict, Any | ||
| import arviz as az | ||
|
|
||
class HMCInference:
    """Run NUTS sampling on a model and return the result as ArviZ data.

    NOTE(review): "This should be named simply Inference because it could
    implement various inference methods (HMC, MCMC, etc)."
    """

    def __init__(self, model, num_warmup=500, num_samples=1000, num_chains=1, dense_mass=False):
        """Store the model and the sampler settings.

        ``model`` is handed unchanged to ``NUTS`` when ``run`` is called.

        NOTE(review): "Why isn't the class" (the rest of this question is cut
        off in the review text).
        NOTE(review): "It is not clear what type of object model is."
        """
        self.model = model
        self.num_warmup, self.num_samples = num_warmup, num_samples
        self.num_chains = num_chains
        self.dense_mass = dense_mass

    def run(self, rng_key, observed_data, extra_args=None):
        """Sample the model and return ``az.from_numpyro(mcmc)``.

        ``observed_data`` is passed to the model as the ``observed_data``
        keyword; ``extra_args`` (a dict, or None for none) is forwarded as
        additional keyword arguments. A summary is printed after sampling.
        """
        model_kwargs = {} if extra_args is None else extra_args

        # NOTE(review): "Do NUTS and MCMC deal with initial samples?"
        sampler = MCMC(
            NUTS(self.model, dense_mass=self.dense_mass),
            num_warmup=self.num_warmup,
            num_samples=self.num_samples,
            num_chains=self.num_chains,
        )
        sampler.run(rng_key, observed_data=observed_data, **model_kwargs)
        sampler.print_summary()

        # Convert to ArviZ InferenceData
        return az.from_numpyro(sampler)
|
|
||
class MAPInference:
    """Fit a model with SVI using an ``AutoDelta`` guide and the Adam optimizer."""

    def __init__(self, model, num_steps=1000, learning_rate=1e-2):
        """Store the model, the number of SVI steps and the Adam step size."""
        self.model = model
        self.num_steps = num_steps
        self.learning_rate = learning_rate

    def run(self, rng_key, observed_data, extra_args=None):
        """Optimise the guide and return ``guide.median(params)``.

        ``observed_data`` is passed to the model as the ``observed_data``
        keyword; ``extra_args`` (a dict, or None for none) is forwarded as
        additional keyword arguments.
        """
        model_kwargs = {} if extra_args is None else extra_args

        delta_guide = AutoDelta(self.model)
        adam = numpyro.optim.Adam(step_size=self.learning_rate)
        svi = SVI(self.model, delta_guide, adam, loss=Trace_ELBO())

        fit = svi.run(rng_key, self.num_steps, observed_data=observed_data, **model_kwargs)

        # The estimates are read through the guide's median rather than taken
        # from the raw params: per the original author's note, the raw
        # parameter names (e.g. an `_auto_loc` suffix) and the space they live
        # in may depend on the NumPyro version -- TODO confirm.
        return delta_guide.median(fit.params)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these parameters for the data generation or for the generative forward modelling?
Missing ellipticity prior information.