Skip to content

Commit 62ba551

Browse files
committed
Test accuracy of sign prediction in disp train test
1 parent 23f26bc commit 62ba551

File tree

4 files changed

+17
-2
lines changed

4 files changed

+17
-2
lines changed

src/ctapipe/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,7 @@ def disp_reconstructor_path(model_tmp_path, gamma_train_clf):
642642
from ctapipe.tools.train_disp_reconstructor import TrainDispReconstructor
643643

644644
out_file = model_tmp_path / "disp_reconstructor.pkl"
645+
cv_out_file = model_tmp_path / "cv_disp_reconstructor.h5"
645646
with FileLock(out_file.with_suffix(out_file.suffix + ".lock")):
646647
if out_file.is_file():
647648
return out_file
@@ -653,12 +654,13 @@ def disp_reconstructor_path(model_tmp_path, gamma_train_clf):
653654
argv=[
654655
f"--input={gamma_train_clf}",
655656
f"--output={out_file}",
657+
f"--cv-output={cv_out_file}",
656658
f"--config={config}",
657659
"--log-level=INFO",
658660
],
659661
)
660662
assert ret == 0
661-
return out_file
663+
return out_file, cv_out_file
662664

663665

664666
@pytest.fixture(scope="session")

src/ctapipe/tools/tests/test_apply_models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def test_apply_all(
110110
):
111111
from ctapipe.tools.apply_models import ApplyModels
112112

113+
disp_reconstructor_path, _ = disp_reconstructor_path
114+
113115
input_path = get_dataset_path("gamma_diffuse_dl2_train_small.dl2.h5")
114116
output_path = tmp_path / "particle-and-energy-and-disp.dl2.h5"
115117

src/ctapipe/tools/tests/test_process_ml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def test_process_apply_disp(
122122
):
123123
from ctapipe.tools.process import ProcessorTool
124124

125+
disp_reconstructor_path, _ = disp_reconstructor_path
126+
125127
output = tmp_path / "gamma_prod5.dl2_disp.h5"
126128

127129
config_path = tmp_path / "config.json"

src/ctapipe/tools/tests/test_train.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from ctapipe.core import ToolConfigurationError, run_tool
55
from ctapipe.exceptions import TooFewEvents
6+
from ctapipe.io import read_table
67
from ctapipe.utils.datasets import resource_file
78

89

@@ -21,7 +22,15 @@ def test_train_particle_classifier(particle_classifier_path):
2122
def test_train_disp_reconstructor(disp_reconstructor_path):
2223
from ctapipe.reco import DispReconstructor
2324

24-
DispReconstructor.read(disp_reconstructor_path)
25+
model_path, cv_path = disp_reconstructor_path
26+
27+
DispReconstructor.read(model_path)
28+
29+
cv_table = read_table(cv_path, "/cv_predictions/LST_LST_LSTCam")
30+
disp = cv_table["disp_parameter"]
31+
true_disp = cv_table["truth"]
32+
accuracy = np.count_nonzero(np.sign(disp) == np.sign(true_disp)) / len(disp)
33+
assert accuracy > 0.75
2534

2635

2736
def test_too_few_events(tmp_path, dl2_shower_geometry_file):

0 commit comments

Comments
 (0)