33"""
44
55import pytest
6+ import shutil
7+
68from ctapipe .core import run_tool
79from ctapipe .utils import get_dataset_path
10+ from ctlearn .tools import TrainCTLearnModel
811
912@pytest .fixture (scope = "session" )
1013def gamma_simtel_path ():
@@ -33,13 +36,14 @@ def dl1_gamma_file(dl1_tmp_path, gamma_simtel_path):
3336 """
3437 from ctapipe .tools .process import ProcessorTool
3538
39+ allowed_tels = {7 , 13 , 15 , 16 , 17 , 19 }
3640 output = dl1_tmp_path / "gamma.dl1.h5"
37-
3841 argv = [
3942 f"--input={ gamma_simtel_path } " ,
4043 f"--output={ output } " ,
4144 "--write-images" ,
4245 "--SimTelEventSource.focal_length_choice=EQUIVALENT" ,
46+ f"--SimTelEventSource.allowed_tels={ allowed_tels } " ,
4347 ]
4448 assert run_tool (ProcessorTool (), argv = argv , cwd = dl1_tmp_path ) == 0
4549 return output
@@ -51,12 +55,14 @@ def dl1_proton_file(dl1_tmp_path, proton_simtel_path):
5155 """
5256 from ctapipe .tools .process import ProcessorTool
5357
58+ allowed_tels = {7 , 13 , 15 , 16 , 17 , 19 }
5459 output = dl1_tmp_path / "proton.dl1.h5"
5560 argv = [
5661 f"--input={ proton_simtel_path } " ,
5762 f"--output={ output } " ,
5863 "--write-images" ,
5964 "--SimTelEventSource.focal_length_choice=EQUIVALENT" ,
65+ f"--SimTelEventSource.allowed_tels={ allowed_tels } " ,
6066 ]
6167 assert run_tool (ProcessorTool (), argv = argv , cwd = dl1_tmp_path ) == 0
6268 return output
@@ -77,4 +83,56 @@ def r1_gamma_file(r1_tmp_path, gamma_simtel_path):
7783 "--SimTelEventSource.focal_length_choice=EQUIVALENT" ,
7884 ]
7985 assert run_tool (ProcessorTool (), argv = argv , cwd = r1_tmp_path ) == 0
80- return output
86+ return output
87+
88+ @pytest .fixture (scope = "session" )
89+ def ctlearn_trained_dl1_models (dl1_gamma_file , dl1_proton_file , tmp_path_factory ):
90+ """
91+ Test training CTLearn model using the DL1 gamma and proton files for all reconstruction tasks.
92+ Each test run gets its own isolated temp directories.
93+ """
94+ tmp_path = tmp_path_factory .mktemp ("ctlearn_models" )
95+
96+ # Temporary directories for signal and background
97+ signal_dir = tmp_path / "gamma_dl1"
98+ signal_dir .mkdir (parents = True , exist_ok = True )
99+
100+ background_dir = tmp_path / "proton_dl1"
101+ background_dir .mkdir (parents = True , exist_ok = True )
102+
103+ # Hardcopy DL1 gamma file to the signal directory
104+ shutil .copy (dl1_gamma_file , signal_dir )
105+ # Hardcopy DL1 proton file to the background directory
106+ shutil .copy (dl1_proton_file , background_dir )
107+
108+ ctlearn_trained_dl1_models = {}
109+ for reco_task in ["type" , "energy" , "cameradirection" ]:
110+ # Output directory for trained model
111+ output_dir = tmp_path / f"ctlearn_{ reco_task } "
112+
113+ # Build command-line arguments
114+ argv = [
115+ f"--signal={ signal_dir } " ,
116+ "--pattern-signal=*.dl1.h5" ,
117+ f"--output={ output_dir } " ,
118+ f"--reco={ reco_task } " ,
119+ "--TrainCTLearnModel.n_epochs=1" ,
120+ "--TrainCTLearnModel.batch_size=2" ,
121+ "--DLImageReader.focal_length_choice=EQUIVALENT" ,
122+ ]
123+
124+ # Include background only for classification task
125+ if reco_task == "type" :
126+ argv .extend ([
127+ f"--background={ background_dir } " ,
128+ "--pattern-background=*.dl1.h5" ,
129+ "--DLImageReader.enforce_subarray_equality=False" ,
130+ ])
131+
132+ # Run training
133+ assert run_tool (TrainCTLearnModel (), argv = argv , cwd = tmp_path ) == 0
134+
135+ ctlearn_trained_dl1_models [reco_task ] = output_dir / "ctlearn_model.keras"
136+ # Check that the trained model exists
137+ assert ctlearn_trained_dl1_models [reco_task ].exists ()
138+ return ctlearn_trained_dl1_models
0 commit comments