Skip to content

Commit 997b322

Browse files
authored
Merge pull request #85 from CosmoStat/feature_unit_tests_run_configs
Feature unit tests run configs - MetricsConfigHandler Class
2 parents 3dae190 + 52214a0 commit 997b322

File tree

11 files changed

+514
-87
lines changed

11 files changed

+514
-87
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ htmlcov/
6262
.coverage.*
6363
.cache
6464
nosetests.xml
65+
pytest.xml
6566
coverage.xml
6667
*.cover
6768
*.py,cover

src/wf_psf/psf_models/psf_models.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from tensorflow.python.keras.engine import data_adapter
1313
from wf_psf.utils.utils import PI_zernikes, zernike_generator
1414
from wf_psf.sims.SimPSFToolkit import SimPSFToolkit
15+
import glob
1516
from sys import exit
1617
import logging
1718

@@ -21,7 +22,13 @@
2122

2223

2324
class PsfModelError(Exception):
24-
pass
25+
"""PSF Model Parameter Error exception class for specific error scenarios."""
26+
27+
def __init__(
28+
self, message="An error with your PSF model parameter settings occurred."
29+
):
30+
self.message = message
31+
super().__init__(self.message)
2532

2633

2734
def register_psfclass(psf_class):
@@ -67,10 +74,9 @@ def set_psf_model(model_name):
6774

6875
try:
6976
psf_class = PSF_CLASS[model_name]
70-
except KeyError:
71-
logger.exception("PSF model entered is invalid. Check your config settings.")
72-
exit()
73-
77+
except KeyError as e:
78+
logger.exception(e)
79+
raise PsfModelError("PSF model entered is invalid. Check your config settings.")
7480
return psf_class
7581

7682

@@ -102,6 +108,31 @@ def get_psf_model(model_params, training_hparams, *coeff_matrix):
102108
return psf_class(model_params, training_hparams, *coeff_matrix)
103109

104110

111+
def get_psf_model_weights_filepath(weights_filepath):
112+
"""Get PSF model weights filepath.
113+
114+
A function to return the basename of the user-specified psf model weights path.
115+
116+
Parameters
117+
----------
118+
weights_filepath: str
119+
Basename of the psf model weights to be loaded.
120+
121+
Returns
122+
-------
123+
str
124+
The absolute path concatenated to the basename of the psf model weights to be loaded.
125+
126+
"""
127+
try:
128+
return glob.glob(weights_filepath)[0].split(".")[0]
129+
except IndexError:
130+
logger.exception(
131+
"PSF weights file not found. Check that you've specified the correct weights file in the metrics config file."
132+
)
133+
raise PsfModelError("PSF model weights error.")
134+
135+
105136
def tf_zernike_cube(n_zernikes, pupil_diam):
106137
"""Tensor Flow Zernike Cube.
107138
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
---
2+
metrics_conf: metrics_config.yaml
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
metrics:
2+
# Specify the type of model weights to load by entering "psf_model" to load weights of final psf model or "checkpoint" to load weights from a checkpoint callback.
3+
model_save_path: checkpoint
4+
# Choose the training cycle for which to evaluate the psf_model. Can be: 1, 2, ...
5+
saved_training_cycle: 2
6+
# Metrics-only run: Specify model_params for a pre-trained model else leave blank if running training + metrics
7+
# Specify path to Parent Directory of Trained Model
8+
trained_model_path: src/wf_psf/tests/data/validation/main_random_seed
9+
# Name of the Trained Model Config file stored in config sub-directory in the trained_model_path parent directory
10+
trained_model_config: training_config.yaml
11+
#Evaluate the monchromatic RMSE metric.
12+
eval_mono_metric_rmse: True
13+
#Evaluate the OPD RMSE metric.
14+
eval_opd_metric_rmse: True
15+
#Evaluate the super-resolution and the shape RMSE metrics for the train dataset.
16+
eval_train_shape_sr_metric_rmse: True
17+
# Name of Plotting Config file - Enter name of yaml file to run plot metrics else if empty run metrics evaluation only
18+
plotting_config: <enter name of plotting_config .yaml file or leave empty>
19+
ground_truth_model:
20+
model_params:
21+
#Model used as ground truth for the evaluation. Options are: 'poly' for polychromatic and 'physical' [not available].
22+
model_name: poly
23+
24+
# Evaluation parameters
25+
#Number of bins used for the ground truth model poly PSF generation
26+
n_bins_lda: 20
27+
28+
#Downsampling rate to match the oversampled model to the specified telescope's sampling.
29+
output_Q: 3
30+
31+
#Oversampling rate used for the OPD/WFE PSF model.
32+
oversampling_rate: 3
33+
34+
#Dimension of the pixel PSF postage stamp
35+
output_dim: 32
36+
37+
#Dimension of the OPD/Wavefront space."
38+
pupil_diameter: 256
39+
40+
#Boolean to define if we use sample weights based on the noise standard deviation estimation
41+
use_sample_weights: True
42+
43+
#Interpolation type for the physical poly model. Options are: 'none', 'all', 'top_K', 'independent_Zk'."
44+
interpolation_type: None
45+
46+
# SED intepolation points per bin
47+
sed_interp_pts_per_bin: 0
48+
49+
# SED extrapolate
50+
sed_extrapolate: True
51+
52+
# SED interpolate kind
53+
sed_interp_kind: linear
54+
55+
# Standard deviation of the multiplicative SED Gaussian noise.
56+
sed_sigma: 0
57+
58+
#Limits of the PSF field coordinates for the x axis.
59+
x_lims: [0.0, 1.0e+3]
60+
61+
#Limits of the PSF field coordinates for the y axis.
62+
y_lims: [0.0, 1.0e+3]
63+
64+
# Hyperparameters for Parametric model
65+
param_hparams:
66+
# Random seed for Tensor Flow Initialization
67+
random_seed: 3877572
68+
69+
# Parameter for the l2 loss function for the Optical path differences (OPD)/WFE
70+
l2_param: 0.
71+
72+
#Zernike polynomial modes to use on the parametric part.
73+
n_zernikes: 45
74+
75+
#Max polynomial degree of the parametric part.
76+
d_max: 2
77+
78+
#Flag to save optimisation history for parametric model
79+
save_optim_history_param: true
80+
81+
# Hyperparameters for non-parametric model
82+
nonparam_hparams:
83+
#Max polynomial degree of the non-parametric part.
84+
d_max_nonparam: 5
85+
86+
# Number of graph features
87+
num_graph_features: 10
88+
89+
#L1 regularisation parameter for the non-parametric part."
90+
l1_rate: 1.0e-8
91+
92+
#Flag to enable Projected learning for DD_features to be used with `poly` or `semiparametric` model.
93+
project_dd_features: False
94+
95+
#Flag to reset DD_features to be used with `poly` or `semiparametric` model
96+
reset_dd_features: False
97+
98+
#Flag to save optimisation history for non-parametric model
99+
save_optim_history_nonparam: True
100+
101+
metrics_hparams:
102+
# Batch size to use for the evaluation.
103+
batch_size: 16
104+
105+
#Save RMS error for each super resolved PSF in the test dataset in addition to the mean across the FOV."
106+
#Flag to get Super-Resolution pixel PSF RMSE for each individual test star.
107+
#If `True`, the relative pixel RMSE of each star is added to ther saving dictionary.
108+
opt_stars_rel_pix_rmse: False
109+
110+
## Specific parameters
111+
# Parameter for the l2 loss of the OPD.
112+
l2_param: 0.
113+
114+
## Define the resolution at which you'd like to measure the shape of the PSFs
115+
#Downsampling rate from the high-resolution pixel modelling space.
116+
# Recommended value: 1
117+
output_Q: 1
118+
119+
#Dimension of the pixel PSF postage stamp; it should be big enough so that most of the signal is contained inside the postage stamp.
120+
# It also depends on the Q values used.
121+
# Recommended value: 64 or higher
122+
output_dim: 64
123+

src/wf_psf/tests/data/validation/main_random_seed/config/training_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
training:
22
# ID name
3-
id_name: _validation
3+
id_name: _sample_w_bis1_2k
44
# Name of Data Config file
55
data_config: data_config.yaml
66
# Metrics Config file - Enter file to run metrics evaluation else if empty run train only
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
training:
2+
# ID name
3+
id_name: _errorsample_w_bis1_2k
4+
# Name of Data Config file
5+
data_config: data_config.yaml
6+
# Metrics Config file - Enter file to run metrics evaluation else if empty run train only
7+
metrics_config: metrics_config.yaml
8+
model_params:
9+
# Model type. Options are: 'mccd', 'graph', 'poly, 'param', 'poly_physical'."
10+
model_name: poly
11+
12+
#Num of wavelength bins to reconstruct polychromatic objects.
13+
n_bins_lda: 8
14+
15+
#Downsampling rate to match the oversampled model to the specified telescope's sampling.
16+
output_Q: 3
17+
18+
#Oversampling rate used for the OPD/WFE PSF model.
19+
oversampling_rate: 3
20+
21+
#Dimension of the pixel PSF postage stamp
22+
output_dim: 32
23+
24+
#Dimension of the OPD/Wavefront space."
25+
pupil_diameter: 256
26+
27+
#Boolean to define if we use sample weights based on the noise standard deviation estimation
28+
use_sample_weights: True
29+
30+
#Interpolation type for the physical poly model. Options are: 'none', 'all', 'top_K', 'independent_Zk'."
31+
interpolation_type: None
32+
33+
# SED intepolation points per bin
34+
sed_interp_pts_per_bin: 0
35+
36+
# SED extrapolate
37+
sed_extrapolate: True
38+
39+
# SED interpolate kind
40+
sed_interp_kind: linear
41+
42+
# Standard deviation of the multiplicative SED Gaussian noise.
43+
sed_sigma: 0
44+
45+
#Limits of the PSF field coordinates for the x axis.
46+
x_lims: [0.0, 1.0e+3]
47+
48+
#Limits of the PSF field coordinates for the y axis.
49+
y_lims: [0.0, 1.0e+3]
50+
51+
# Hyperparameters for Parametric model
52+
param_hparams:
53+
# Random seed for Tensor Flow Initialization
54+
random_seed: 3877572
55+
56+
# Parameter for the l2 loss function for the Optical path differences (OPD)/WFE
57+
l2_param: 0.
58+
59+
#Zernike polynomial modes to use on the parametric part.
60+
n_zernikes: 15
61+
62+
#Max polynomial degree of the parametric part. chg to max_deg_param
63+
d_max: 2
64+
65+
#Flag to save optimisation history for parametric model
66+
save_optim_history_param: true
67+
68+
# Hyperparameters for non-parametric model
69+
nonparam_hparams:
70+
71+
#Max polynomial degree of the non-parametric part. chg to max_deg_nonparam
72+
d_max_nonparam: 5
73+
74+
# Number of graph features
75+
num_graph_features: 10
76+
77+
#L1 regularisation parameter for the non-parametric part."
78+
l1_rate: 1.0e-8
79+
80+
#Flag to enable Projected learning for DD_features to be used with `poly` or `semiparametric` model.
81+
project_dd_features: False
82+
83+
#Flag to reset DD_features to be used with `poly` or `semiparametric` model
84+
reset_dd_features: False
85+
86+
#Flag to save optimisation history for non-parametric model
87+
save_optim_history_nonparam: true
88+
89+
# Training hyperparameters
90+
training_hparams:
91+
n_epochs_params: [2, 2, 2]
92+
93+
n_epochs_non_params: [2, 2, 2]
94+
95+
batch_size: 32
96+
97+
multi_cycle_params:
98+
99+
# Total amount of cycles to perform.
100+
total_cycles: 2
101+
102+
# Train cycle definition. It can be: 'parametric', 'non-parametric', 'complete', 'only-non-parametric' and 'only-parametric'."
103+
cycle_def: complete
104+
105+
# Make checkpoint at every cycle or just save the checkpoint at the end of the training."
106+
save_all_cycles: True
107+
108+
#"Saved cycle to use for the evaluation. Can be 'cycle1', 'cycle2', ..."
109+
saved_cycle: cycle2
110+
111+
# Learning rates for the parametric parts. It should be a str where numeric values are separated by spaces.
112+
learning_rate_params: [1.0e-2, 1.0e-2]
113+
114+
# Learning rates for the non-parametric parts. It should be a str where numeric values are separated by spaces."
115+
learning_rate_non_params: [1.0e-1, 1.0e-1]
116+
117+
# Number of training epochs of the parametric parts. It should be a strign where numeric values are separated by spaces."
118+
n_epochs_params: [20, 20]
119+
120+
# Number of training epochs of the non-parametric parts. It should be a str where numeric values are separated by spaces."
121+
n_epochs_non_params: [100, 120]
122+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""UNIT TESTS FOR PACKAGE MODULE: PSF MODELS.
2+
3+
This module contains unit tests for the wf_psf.psf_models psf_models module.
4+
5+
:Author: Jennifer Pollack <[email protected]>
6+
7+
8+
"""
9+
10+
import pytest
11+
from wf_psf.psf_models import psf_models
12+
from wf_psf.utils.io import FileIOHandler
13+
import os
14+
15+
16+
def test_get_psf_model_weights_filepath():
17+
weights_filepath = "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint*_poly*_sample_w_bis1_2k_cycle2*"
18+
19+
ans = psf_models.get_psf_model_weights_filepath(weights_filepath)
20+
assert (
21+
ans
22+
== "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint_callback_poly_sample_w_bis1_2k_cycle2"
23+
)

0 commit comments

Comments
 (0)