Skip to content

Commit 42d1ba6

Browse files
authored
Merge pull request #2738 from cta-observatory/fix_disp_sign
Fix wrong proba column being used for disp sign
2 parents ba264d9 + 62ba551 commit 42d1ba6

File tree

6 files changed

+19
-3
lines changed

6 files changed

+19
-3
lines changed

docs/changes/2738.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix inverted sign of ``DispReconstructor`` prediction.

src/ctapipe/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,7 @@ def disp_reconstructor_path(model_tmp_path, gamma_train_clf):
642642
from ctapipe.tools.train_disp_reconstructor import TrainDispReconstructor
643643

644644
out_file = model_tmp_path / "disp_reconstructor.pkl"
645+
cv_out_file = model_tmp_path / "cv_disp_reconstructor.h5"
645646
with FileLock(out_file.with_suffix(out_file.suffix + ".lock")):
646647
if out_file.is_file():
647648
return out_file
@@ -653,12 +654,13 @@ def disp_reconstructor_path(model_tmp_path, gamma_train_clf):
653654
argv=[
654655
f"--input={gamma_train_clf}",
655656
f"--output={out_file}",
657+
f"--cv-output={cv_out_file}",
656658
f"--config={config}",
657659
"--log-level=INFO",
658660
],
659661
)
660662
assert ret == 0
661-
return out_file
663+
return out_file, cv_out_file
662664

663665

664666
@pytest.fixture(scope="session")

src/ctapipe/reco/sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ def _predict(self, key, table):
703703
else:
704704
prediction[valid] = valid_norms
705705

706-
sign_proba = self._models[key][1].predict_proba(X)[:, 0]
706+
sign_proba = self._models[key][1].predict_proba(X)[:, 1]
707707
# proba is [0 and 1] where 0 => very certain -1, 1 => very certain 1
708708
# and 0.5 means random guessing either. So we transform to a score
709709
# where 0 means "guessing" and 1 means "very certain"

src/ctapipe/tools/tests/test_apply_models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def test_apply_all(
110110
):
111111
from ctapipe.tools.apply_models import ApplyModels
112112

113+
disp_reconstructor_path, _ = disp_reconstructor_path
114+
113115
input_path = get_dataset_path("gamma_diffuse_dl2_train_small.dl2.h5")
114116
output_path = tmp_path / "particle-and-energy-and-disp.dl2.h5"
115117

src/ctapipe/tools/tests/test_process_ml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def test_process_apply_disp(
122122
):
123123
from ctapipe.tools.process import ProcessorTool
124124

125+
disp_reconstructor_path, _ = disp_reconstructor_path
126+
125127
output = tmp_path / "gamma_prod5.dl2_disp.h5"
126128

127129
config_path = tmp_path / "config.json"

src/ctapipe/tools/tests/test_train.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from ctapipe.core import ToolConfigurationError, run_tool
55
from ctapipe.exceptions import TooFewEvents
6+
from ctapipe.io import read_table
67
from ctapipe.utils.datasets import resource_file
78

89

@@ -21,7 +22,15 @@ def test_train_particle_classifier(particle_classifier_path):
2122
def test_train_disp_reconstructor(disp_reconstructor_path):
2223
from ctapipe.reco import DispReconstructor
2324

24-
DispReconstructor.read(disp_reconstructor_path)
25+
model_path, cv_path = disp_reconstructor_path
26+
27+
DispReconstructor.read(model_path)
28+
29+
cv_table = read_table(cv_path, "/cv_predictions/LST_LST_LSTCam")
30+
disp = cv_table["disp_parameter"]
31+
true_disp = cv_table["truth"]
32+
accuracy = np.count_nonzero(np.sign(disp) == np.sign(true_disp)) / len(disp)
33+
assert accuracy > 0.75
2534

2635

2736
def test_too_few_events(tmp_path, dl2_shower_geometry_file):

0 commit comments

Comments
 (0)