@@ -34,28 +34,27 @@ def forward(self, x):
3434 return x_res .squeeze (0 )
3535
3636
37-
3837def download_checkpoint_from_wandb (artifact_path , project_name = "ghost-irim" ):
3938 print (f"Downloading checkpoint from W&B: { artifact_path } " )
40-
39+
4140 wandb_api_key = os .environ .get ("WANDB_API_KEY" )
4241 if wandb_api_key :
4342 wandb .login (key = wandb_api_key )
44-
43+
4544 run = wandb .init (project = project_name , job_type = "inference" )
46-
45+
4746 artifact = run .use_artifact (artifact_path , type = "model" )
4847 artifact_dir = artifact .download ()
49-
48+
5049 artifact_path_obj = Path (artifact_dir )
5150 checkpoint_files = list (artifact_path_obj .glob ("*.ckpt" ))
52-
51+
5352 if not checkpoint_files :
5453 raise FileNotFoundError (f"No .ckpt file found in artifact directory: { artifact_dir } " )
55-
54+
5655 checkpoint_path = checkpoint_files [0 ]
5756 print (f"Checkpoint downloaded to: { checkpoint_path } " )
58-
57+
5958 return checkpoint_path
6059
6160
@@ -69,7 +68,7 @@ def main():
6968 else :
7069 device = "cpu"
7170 print (f"Using device: { device } " )
72-
71+
7372 model_name = config .model .name
7473 mask_size = config .inference .get ("mask_size" , 224 )
7574 image_size = 299 if model_name == "inception_v3" else 224
@@ -84,25 +83,19 @@ def main():
8483 test_data = dataset ["test" ]
8584 test_dataset = ForestDataset (test_data ["paths" ], test_data ["labels" ], transform = transforms )
8685
87- test_loader = torch .utils .data .DataLoader (
88- test_dataset , batch_size = 1 , shuffle = False , num_workers = 2
89- )
86+ test_loader = torch .utils .data .DataLoader (test_dataset , batch_size = 1 , shuffle = False , num_workers = 2 )
9087
9188 num_classes = len (label_map )
9289
9390 # =========================== MODEL LOADING ==================================== #
9491 wandb_artifact = config .inference .get ("wandb_artifact" , None )
95-
92+
9693 if wandb_artifact :
9794 wandb_project = config .inference .get ("wandb_project" , "ghost-irim" )
9895 checkpoint_path = download_checkpoint_from_wandb (wandb_artifact , wandb_project )
9996 else :
100- raise FileNotFoundError (
101- f"Checkpoint not found at { checkpoint_path } . "
102- "Please set 'wandb_artifact' in config.yaml to download from W&B, "
103- "or ensure the local checkpoint exists."
104- )
105-
97+ raise FileNotFoundError (f"Checkpoint not found at { checkpoint_path } . Please set 'wandb_artifact' in config.yaml to download from W&B, or ensure the local checkpoint exists." )
98+
10699 print (f"Loading model from: { checkpoint_path } " )
107100
108101 classifier = ClassifierModule .load_from_checkpoint (
@@ -116,15 +109,14 @@ def main():
116109 norm_std = [0.5 , 0.5 , 0.5 ]
117110
118111 seg_model = SegmentationWrapper (
119- classifier ,
112+ classifier ,
120113 mask_size = mask_size ,
121- mean = None , # TODO: fix
122- std = None , # TODO: fix
123- input_rescale = True # Expects 0-255 input, scales to 0-1 internally
114+ mean = None , # TODO: fix
115+ std = None , # TODO: fix
116+ input_rescale = True , # Expects 0-255 input, scales to 0-1 internally
124117 ).to (device )
125118 seg_model .eval ()
126119
127-
128120 # =========================== EXPORT TO ONNX =================================== #
129121 if config .inference .get ("export_onnx" , False ):
130122 dummy_input = torch .randn (1 , 3 , image_size , image_size , device = device )
@@ -140,25 +132,25 @@ def main():
140132 do_constant_folding = True ,
141133 )
142134 print (f"Exported model to { onnx_path .resolve ()} " )
143-
135+
144136 # Add metadata
145137 model_onnx = onnx .load (onnx_path )
146-
138+
147139 class_names = {v : k for k , v in label_map .items ()}
148-
140+
149141 def add_meta (key , value ):
150- meta = model_onnx .metadata_props .add ()
151- meta .key = key
152- meta .value = json .dumps (value )
142+ meta = model_onnx .metadata_props .add ()
143+ meta .key = key
144+ meta .value = json .dumps (value )
153145
154- add_meta (' model_type' , ' Segmentor' )
155- add_meta (' class_names' , class_names )
156- add_meta (' resolution' , 20 )
157- add_meta (' tiles_size' , image_size )
158- add_meta (' tiles_overlap' , 0 )
146+ add_meta (" model_type" , " Segmentor" )
147+ add_meta (" class_names" , class_names )
148+ add_meta (" resolution" , 20 )
149+ add_meta (" tiles_size" , image_size )
150+ add_meta (" tiles_overlap" , 0 )
159151
160152 onnx .save (model_onnx , onnx_path )
161-
153+
162154 if wandb .run is not None :
163155 onnx_artifact = wandb .Artifact (
164156 name = f"segmentation-model-{ model_name } " ,
@@ -170,29 +162,29 @@ def add_meta(key, value):
170162 "image_size" : image_size ,
171163 "format" : "onnx" ,
172164 "opset_version" : 17 ,
173- }
165+ },
174166 )
175167 onnx_artifact .add_file (str (onnx_path ))
176168 wandb .log_artifact (onnx_artifact )
177169 print (f"ONNX model uploaded to W&B artifacts as 'segmentation-model-{ model_name } '" )
178170 else :
179171 print ("Warning: W&B run not initialized. ONNX model not uploaded to artifacts." )
180-
172+
181173 # =========================== INFERENCE LOOP =================================== #
182174 print (f"Running inference on { len (test_loader )} samples..." )
183175 all_preds = []
184176 all_targets = []
185177
186178 with torch .no_grad ():
187- for i , batch in enumerate ( tqdm (test_loader ) ):
179+ for batch in tqdm (test_loader ):
188180 imgs , labels = batch
189181 imgs = imgs .to (device )
190182 labels = labels .to (device )
191183
192184 masks = seg_model (imgs )
193-
185+
194186 probs = masks [:, :, 0 , 0 ]
195-
187+
196188 all_preds .append (probs )
197189 all_targets .append (labels )
198190
@@ -202,7 +194,7 @@ def add_meta(key, value):
202194 # =========================== METRICS & LOGGING ================================ #
203195 if wandb .run is not None :
204196 print ("Calculating and logging metrics..." )
205-
197+
206198 metrics_per_experiment = count_metrics (all_targets , all_preds )
207199 print (f"Test Metrics: { metrics_per_experiment } " )
208200 for key , value in metrics_per_experiment .items ():
@@ -222,7 +214,7 @@ def add_meta(key, value):
222214
223215 plots_dir = Path ("src/plots" )
224216 plots_dir .mkdir (exist_ok = True , parents = True )
225-
217+
226218 get_confusion_matrix (all_preds , all_targets , class_names = list (label_map .keys ()))
227219 get_roc_auc_curve (all_preds , all_targets , class_names = list (label_map .keys ()))
228220 get_precision_recall_curve (all_preds , all_targets , class_names = list (label_map .keys ()))
@@ -234,5 +226,6 @@ def add_meta(key, value):
234226 else :
235227 print ("W&B run not active. Skipping metrics logging." )
236228
229+
237230if __name__ == "__main__" :
238231 main ()
0 commit comments