File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -416,7 +416,7 @@ def compute_shape_metrics(
416416 predictions = tf_semiparam_field .predict (x = pred_inputs , batch_size = batch_size )
417417
418418 # GT data preparation
419- if dataset_dict is None or 'super_res_stars' not in dataset_dict :
419+ if dataset_dict is None or 'super_res_stars' not in dataset_dict or 'SR_stars' not in dataset_dict :
420420 print ('Generating GT super resolved stars from the GT model.' )
421421 # Change interpolation parameters for the GT simPSF
422422 interp_pts_per_bin = simPSF_np .interp_pts_per_bin
@@ -438,7 +438,10 @@ def compute_shape_metrics(
438438
439439 else :
440440 print ('Using super resolved stars from dataset.' )
441- GT_predictions = dataset_dict ['super_res_stars' ]
441+ if 'super_res_stars' in dataset_dict :
442+ GT_predictions = dataset_dict ['super_res_stars' ]
443+ elif 'SR_stars' in dataset_dict :
444+ GT_predictions = dataset_dict ['SR_stars' ]
442445
443446 # Calculate residuals
444447 residuals = np .sqrt (np .mean ((GT_predictions - predictions )** 2 , axis = (1 , 2 )))
You can’t perform that action at this time.
0 commit comments