Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,6 @@ marimo/_static/
marimo/_lsp/
__marimo__/

shine/_version.py
shine/_version.py

results/
20 changes: 10 additions & 10 deletions DESIGN.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,26 @@ SHINE treats shear measurement as a Bayesian inverse problem. Instead of measuri
```mermaid
graph TD
Config[YAML Configuration] --> Loader[Config Handler]
Loader --> Scene[Scene Modelling (NumPyro + JAX-GalSim)]
Loader --> Scene[Scene Modelling<br/>NumPyro + JAX-GalSim]

subgraph "Forward Model (JAX)"
Priors[Priors (NumPyro)] --> Scene
Scene --> Galaxy[Galaxy Generation (Sersic/Morphology)]
Priors[Priors<br/>NumPyro] --> Scene
Scene --> Galaxy[Galaxy Generation<br/>Sersic/Morphology]
Scene --> PSF[PSF Modelling]
Galaxy --> Convolve[Convolution]
PSF --> Convolve
Convolve --> Noise[Noise Model]
Noise --> ModelImage[Simulated Image]
end
Data[Observed Data (Fits/HDF5)] --> Likelihood

Data[Observed Data<br/>Fits/HDF5] --> Likelihood
ModelImage --> Likelihood[Likelihood Evaluation]
Likelihood --> Inference[Inference Engine (NumPyro/BlackJAX)]

Likelihood --> Inference[Inference Engine<br/>NumPyro/BlackJAX]
Inference --> Posterior[Shear Posterior]

subgraph "Workflow Management"
WMS[WMS (Slurm/Cluster)] --> Config
WMS[WMS<br/>Slurm/Cluster] --> Config
WMS --> Data
end
```
Expand Down
38 changes: 38 additions & 0 deletions configs/test_run.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
image:
pixel_scale: 0.2
size_x: 32
size_y: 32
n_objects: 1
noise:
type: Gaussian
sigma: 0.01

psf:
type: Gaussian
sigma: 0.1

gal:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This parameters are for the data generation or the generative forward modelling?

Missing ellipticity prior information.

type: Exponential
flux:
type: LogNormal
mean: 100.0
sigma: 0.1
half_light_radius:
type: Uniform
min: 0.3
max: 0.8
shear:
type: G1G2
g1:
type: Normal
mean: 0.05
sigma: 0.02
g2:
type: Normal
mean: -0.05
sigma: 0.02

inference:
warmup: 100
samples: 100
chains: 1
83 changes: 83 additions & 0 deletions shine/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Union, Optional, Dict, Any, List
from pydantic import BaseModel, Field, validator
import yaml
from pathlib import Path

# --- Distribution Models (for Priors) ---

class DistributionConfig(BaseModel):
type: str
mean: Optional[float] = None
sigma: Optional[float] = None
min: Optional[float] = None
max: Optional[float] = None

# Allow extra fields for other distributions
class Config:
extra = "allow"

# --- Component Models ---

class NoiseConfig(BaseModel):
type: str = "Gaussian"
sigma: float

class ImageConfig(BaseModel):
pixel_scale: float
size_x: int
size_y: int
n_objects: int = 1 # Default to 1 for simple tests
noise: NoiseConfig

class PSFConfig(BaseModel):
type: str = "Gaussian"
sigma: float
beta: Optional[float] = 2.5 # For Moffat

class ShearComponentConfig(BaseModel):
# Can be a fixed float or a distribution
type: Optional[str] = None # If None, assume fixed value in parent or handled elsewhere
mean: Optional[float] = 0.0
sigma: Optional[float] = 0.05

# To handle the case where it's just a float in YAML, we might need a custom validator
# but for now let's assume structured input as per design doc

class ShearConfig(BaseModel):
type: str = "G1G2"
g1: Union[float, DistributionConfig]
g2: Union[float, DistributionConfig]

class GalaxyConfig(BaseModel):
type: str = "Exponential" # Changed default from Sersic to Exponential
n: Optional[Union[float, DistributionConfig]] = None # Make optional for Exponential
flux: Union[float, DistributionConfig]
half_light_radius: Union[float, DistributionConfig] = Field(..., alias="half_light_radius")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does Field do? If it is for adding an alias to half_light_radius shouldn't the alias be hlr? And why is Field only used here?

shear: ShearConfig

class InferenceConfig(BaseModel):
warmup: int = 500
samples: int = 1000
chains: int = 1
dense_mass: bool = False

class ShineConfig(BaseModel):
image: ImageConfig
psf: PSFConfig
gal: GalaxyConfig
inference: InferenceConfig = Field(default_factory=InferenceConfig)
data_path: Optional[str] = None
output_path: str = "results"

class ConfigHandler:
@staticmethod
def load(path: str) -> ShineConfig:
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"Config file not found: {path}")

with open(path, "r") as f:
data = yaml.safe_load(f)

# Basic validation and type conversion via Pydantic
return ShineConfig(**data)
83 changes: 83 additions & 0 deletions shine/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import jax.numpy as jnp
import jax
from dataclasses import dataclass
from typing import Optional, Any, Dict
from shine.config import ShineConfig

@dataclass
class Observation:
image: jnp.ndarray
noise_map: jnp.ndarray
psf_config: Dict[str, Any] # Store PSF config instead of object

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not using the Class PSFConfig defined in config.py.

wcs: Any = None

class DataLoader:
@staticmethod
def load(config: ShineConfig) -> Observation:
if config.data_path and config.data_path != "None":
# TODO: Implement real data loading (Fits/HDF5)
raise NotImplementedError("Real data loading not yet implemented. Use synthetic generation.")
else:
print("No data path provided. Generating synthetic data...")
return DataLoader.generate_synthetic(config)

@staticmethod
def generate_synthetic(config: ShineConfig) -> Observation:
import galsim

# 1. Define PSF

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually this will be a single line:

psf = psf_utils.get_psf(config.psf)

and all the code should be written in the psf_utils sub package/module.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same applies to the following lines.

if config.psf.type == "Gaussian":
psf = galsim.Gaussian(sigma=config.psf.sigma)
else:
raise NotImplementedError(f"PSF type {config.psf.type} not supported for synthetic gen")

# 2. Define Galaxy (using mean values from config for "truth")
def get_mean(param):
if isinstance(param, (float, int)):
return float(param)
# If it's a distribution config
if param.mean is not None:
return param.mean
# Handle Uniform
if param.type == 'Uniform' and param.min is not None and param.max is not None:
return (param.min + param.max) / 2.0
return param.mean # Fallback (might still be None if not handled)

gal_flux = get_mean(config.gal.flux)
gal_hlr = get_mean(config.gal.half_light_radius)

# Shear
g1 = get_mean(config.gal.shear.g1)
g2 = get_mean(config.gal.shear.g2)
shear = galsim.Shear(g1=g1, g2=g2)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing intrinsic ellipticity.

# Create Galaxy Object - Use Exponential (Sersic n=1)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle if we follow the previous logic it should check the galaxy type:

if config.gal.type == "exponential":
...

gal = galsim.Exponential(half_light_radius=gal_hlr, flux=gal_flux)
gal = gal.shear(shear)

# Convolve
final = galsim.Convolve([gal, psf])

# 3. Draw Image
image = final.drawImage(nx=config.image.size_x,
ny=config.image.size_y,
scale=config.image.pixel_scale).array

# 4. Add Noise
rng = galsim.BaseDeviate(0)
noise_sigma = config.image.noise.sigma
noise = galsim.GaussianNoise(rng, sigma=noise_sigma)

# GalSim image for noise addition
gs_image = galsim.Image(image)
gs_image.addNoise(noise)
noisy_image = gs_image.array

noise_map = jnp.ones_like(noisy_image) * (noise_sigma**2)

# Return JAX arrays and PSF config
return Observation(
image=jnp.array(noisy_image),
noise_map=noise_map,
psf_config={'type': config.psf.type, 'sigma': config.psf.sigma}
)
55 changes: 55 additions & 0 deletions shine/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import jax
import numpyro
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from typing import Dict, Any
import arviz as az

class HMCInference:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be named simply Inference because it could implement various inference methods (HMC, MCMC, etc).

def __init__(self, model, num_warmup=500, num_samples=1000, num_chains=1, dense_mass=False):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't the class InferenceConfig used here?

self.model = model

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not clear what type of object model is.
From scene.py I see model is a function (the loglike function) returned by build_model.

self.num_warmup = num_warmup
self.num_samples = num_samples
self.num_chains = num_chains
self.dense_mass = dense_mass

def run(self, rng_key, observed_data, extra_args=None):
if extra_args is None:
extra_args = {}

kernel = NUTS(self.model, dense_mass=self.dense_mass)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do NUTS and MCMC deal with initial samples?

mcmc = MCMC(kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, num_chains=self.num_chains)

mcmc.run(rng_key, observed_data=observed_data, **extra_args)
mcmc.print_summary()

# Convert to ArviZ InferenceData
return az.from_numpyro(mcmc)

class MAPInference:
def __init__(self, model, num_steps=1000, learning_rate=1e-2):
self.model = model
self.num_steps = num_steps
self.learning_rate = learning_rate

def run(self, rng_key, observed_data, extra_args=None):
if extra_args is None:
extra_args = {}

guide = AutoDelta(self.model)
optimizer = numpyro.optim.Adam(step_size=self.learning_rate)
svi = SVI(self.model, guide, optimizer, loss=Trace_ELBO())

svi_result = svi.run(rng_key, self.num_steps, observed_data=observed_data, **extra_args)

params = svi_result.params
# The params from AutoDelta are the MAP estimates (in unconstrained space usually,
# but AutoDelta returns constrained values if using `median` init or similar,
# actually AutoDelta parameters are the values themselves).

# We need to sample from the guide to get the values in the proper structure if needed,
# but for AutoDelta, params contains the values.
# Note: AutoDelta names parameters with `_auto_loc` suffix sometimes or keeps original names depending on version.
# Let's check the guide median.

return guide.median(params)
Loading