Skip to content

Commit 899a6e3

Browse files
Add compatibility with old dataset name convention
1 parent 72bf515 commit 899a6e3

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

wf_psf/metrics.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def compute_shape_metrics(
416416
predictions = tf_semiparam_field.predict(x=pred_inputs, batch_size=batch_size)
417417

418418
# GT data preparation
419-
if dataset_dict is None or 'super_res_stars' not in dataset_dict:
419+
if dataset_dict is None or 'super_res_stars' not in dataset_dict or 'SR_stars' not in dataset_dict:
420420
print('Generating GT super resolved stars from the GT model.')
421421
# Change interpolation parameters for the GT simPSF
422422
interp_pts_per_bin = simPSF_np.interp_pts_per_bin
@@ -438,7 +438,10 @@ def compute_shape_metrics(
438438

439439
else:
440440
print('Using super resolved stars from dataset.')
441-
GT_predictions = dataset_dict['super_res_stars']
441+
if 'super_res_stars' in dataset_dict:
442+
GT_predictions = dataset_dict['super_res_stars']
443+
elif 'SR_stars' in dataset_dict:
444+
GT_predictions = dataset_dict['SR_stars']
442445

443446
# Calculate residuals
444447
residuals = np.sqrt(np.mean((GT_predictions - predictions)**2, axis=(1, 2)))

0 commit comments

Comments
 (0)