Skip to content

Commit b2a5f6b

Browse files
committed
add test for prediction tool
1 parent b4d2269 commit b2a5f6b

File tree

3 files changed

+137
-6
lines changed

3 files changed

+137
-6
lines changed

ctlearn/conftest.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
"""
44

55
import pytest
6+
import shutil
7+
68
from ctapipe.core import run_tool
79
from ctapipe.utils import get_dataset_path
10+
from ctlearn.tools import TrainCTLearnModel
811

912
@pytest.fixture(scope="session")
1013
def gamma_simtel_path():
@@ -33,13 +36,14 @@ def dl1_gamma_file(dl1_tmp_path, gamma_simtel_path):
3336
"""
3437
from ctapipe.tools.process import ProcessorTool
3538

39+
allowed_tels = {7, 13, 15, 16, 17, 19}
3640
output = dl1_tmp_path / "gamma.dl1.h5"
37-
3841
argv = [
3942
f"--input={gamma_simtel_path}",
4043
f"--output={output}",
4144
"--write-images",
4245
"--SimTelEventSource.focal_length_choice=EQUIVALENT",
46+
f"--SimTelEventSource.allowed_tels={allowed_tels}",
4347
]
4448
assert run_tool(ProcessorTool(), argv=argv, cwd=dl1_tmp_path) == 0
4549
return output
@@ -51,12 +55,14 @@ def dl1_proton_file(dl1_tmp_path, proton_simtel_path):
5155
"""
5256
from ctapipe.tools.process import ProcessorTool
5357

58+
allowed_tels = {7, 13, 15, 16, 17, 19}
5459
output = dl1_tmp_path / "proton.dl1.h5"
5560
argv = [
5661
f"--input={proton_simtel_path}",
5762
f"--output={output}",
5863
"--write-images",
5964
"--SimTelEventSource.focal_length_choice=EQUIVALENT",
65+
f"--SimTelEventSource.allowed_tels={allowed_tels}",
6066
]
6167
assert run_tool(ProcessorTool(), argv=argv, cwd=dl1_tmp_path) == 0
6268
return output
@@ -77,4 +83,56 @@ def r1_gamma_file(r1_tmp_path, gamma_simtel_path):
7783
"--SimTelEventSource.focal_length_choice=EQUIVALENT",
7884
]
7985
assert run_tool(ProcessorTool(), argv=argv, cwd=r1_tmp_path) == 0
80-
return output
86+
return output
87+
88+
@pytest.fixture(scope="session")
89+
def ctlearn_trained_dl1_models(dl1_gamma_file, dl1_proton_file, tmp_path_factory):
90+
"""
91+
Test training CTLearn model using the DL1 gamma and proton files for all reconstruction tasks.
92+
Each test run gets its own isolated temp directories.
93+
"""
94+
tmp_path = tmp_path_factory.mktemp("ctlearn_models")
95+
96+
# Temporary directories for signal and background
97+
signal_dir = tmp_path / "gamma_dl1"
98+
signal_dir.mkdir(parents=True, exist_ok=True)
99+
100+
background_dir = tmp_path / "proton_dl1"
101+
background_dir.mkdir(parents=True, exist_ok=True)
102+
103+
# Hardcopy DL1 gamma file to the signal directory
104+
shutil.copy(dl1_gamma_file, signal_dir)
105+
# Hardcopy DL1 proton file to the background directory
106+
shutil.copy(dl1_proton_file, background_dir)
107+
108+
ctlearn_trained_dl1_models = {}
109+
for reco_task in ["type", "energy", "cameradirection"]:
110+
# Output directory for trained model
111+
output_dir = tmp_path / f"ctlearn_{reco_task}"
112+
113+
# Build command-line arguments
114+
argv = [
115+
f"--signal={signal_dir}",
116+
"--pattern-signal=*.dl1.h5",
117+
f"--output={output_dir}",
118+
f"--reco={reco_task}",
119+
"--TrainCTLearnModel.n_epochs=1",
120+
"--TrainCTLearnModel.batch_size=2",
121+
"--DLImageReader.focal_length_choice=EQUIVALENT",
122+
]
123+
124+
# Include background only for classification task
125+
if reco_task == "type":
126+
argv.extend([
127+
f"--background={background_dir}",
128+
"--pattern-background=*.dl1.h5",
129+
"--DLImageReader.enforce_subarray_equality=False",
130+
])
131+
132+
# Run training
133+
assert run_tool(TrainCTLearnModel(), argv=argv, cwd=tmp_path) == 0
134+
135+
ctlearn_trained_dl1_models[reco_task] = output_dir / "ctlearn_model.keras"
136+
# Check that the trained model exists
137+
assert ctlearn_trained_dl1_models[reco_task].exists()
138+
return ctlearn_trained_dl1_models
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import shutil
2+
import numpy as np
3+
4+
from ctapipe.core import run_tool
5+
from ctapipe.io import TableLoader
6+
from ctlearn.tools import MonoPredictCTLearnModel
7+
8+
def test_predict_model(ctlearn_trained_dl1_models, dl1_gamma_file, tmp_path):
9+
"""
10+
Test training CTLearn model using the DL1 gamma and proton files for all reconstruction tasks.
11+
Each test run gets its own isolated temp directories.
12+
"""
13+
14+
model_dir = tmp_path / "trained_models"
15+
model_dir.mkdir(parents=True, exist_ok=True)
16+
17+
dl2_dir = tmp_path / "dl2_output"
18+
dl2_dir.mkdir(parents=True, exist_ok=True)
19+
20+
# Hardcopy the trained models to the model directory
21+
for reco_task in ["type", "energy", "cameradirection"]:
22+
shutil.copy(ctlearn_trained_dl1_models[f"{reco_task}"], model_dir / f"ctlearn_model_{reco_task}.keras")
23+
model_file = model_dir / f"ctlearn_model_{reco_task}.keras"
24+
assert model_file.exists(), f"Trained model file not found for {reco_task}"
25+
26+
# Build command-line arguments
27+
output_file = dl2_dir / "gamma.dl2.h5"
28+
argv = [
29+
f"--input_url={dl1_gamma_file}",
30+
f"--output={output_file}",
31+
"--PredictCTLearnModel.batch_size=4",
32+
"--DLImageReader.focal_length_choice=EQUIVALENT",
33+
]
34+
35+
# Run Prediction for energy and type together
36+
assert run_tool(
37+
MonoPredictCTLearnModel(),
38+
argv = argv + [
39+
f"--PredictCTLearnModel.load_type_model_from={model_dir}/ctlearn_model_type.keras",
40+
f"--PredictCTLearnModel.load_energy_model_from={model_dir}/ctlearn_model_energy.keras",
41+
"--use-HDF5Merger",
42+
"--no-dl1-images",
43+
"--no-true-images",
44+
],
45+
cwd=tmp_path
46+
) == 0
47+
48+
assert run_tool(
49+
MonoPredictCTLearnModel(),
50+
argv= argv + [
51+
f"--PredictCTLearnModel.load_cameradirection_model_from="
52+
f"{model_dir}/ctlearn_model_cameradirection.keras",
53+
"--no-use-HDF5Merger",
54+
],
55+
cwd=tmp_path,
56+
) == 0
57+
58+
59+
allowed_tels = [7, 13, 15, 16, 17, 19]
60+
required_columns = [
61+
"telescope_pointing_azimuth",
62+
"telescope_pointing_altitude",
63+
"CTLearn_alt",
64+
"CTLearn_az",
65+
"CTLearn_prediction",
66+
"CTLearn_energy",
67+
]
68+
# Check that the output DL2 file was created
69+
assert output_file.exists(), "Output DL2 file not created"
70+
# Check that the created DL2 file can be read with the TableLoader
71+
with TableLoader(output_file, pointing=True, focal_length_choice="EQUIVALENT") as loader:
72+
events = loader.read_telescope_events_by_id(telescopes=allowed_tels)
73+
for tel_id in allowed_tels:
74+
assert len(events[tel_id]) > 0
75+
for col in required_columns:
76+
assert col in events[tel_id].colnames, f"{col} missing in DL2 file {output_file.name}"
77+
assert events[tel_id][col][0] is not np.nan, f"{col} has NaN values in DL2 file {output_file.name}"

ctlearn/tools/tests/test_train_model.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,11 @@ def test_train_ctlearn_model(reco_task, dl1_gamma_file, dl1_proton_file, tmp_pat
3939

4040
# Include background only for classification task
4141
if reco_task == "type":
42-
allowed_tels = {7, 13, 15, 16, 17, 19}
4342
argv.extend([
4443
f"--background={background_dir}",
4544
"--pattern-background=*.dl1.h5",
46-
f"--DLImageReader.allowed_tels={allowed_tels}",
4745
"--DLImageReader.enforce_subarray_equality=False",
4846
])
49-
else:
50-
argv.extend(["--DLImageReader.allowed_tel_types=LST_LST_LSTCam"])
5147

5248
# Run training
5349
assert run_tool(TrainCTLearnModel(), argv=argv, cwd=tmp_path) == 0

0 commit comments

Comments
 (0)