|
35 | 35 | "_select_obs_by_indices", |
36 | 36 | "_select_obs_by_coords", |
37 | 37 | "_prepare_full_arrays", |
| 38 | + "_validate_crps_input", |
38 | 39 | ] |
39 | 40 |
|
40 | 41 | LooInputs = namedtuple( |
@@ -998,3 +999,59 @@ def _reconstruct_upars(upars_new_values, props): |
998 | 999 | dims=props["dims"], |
999 | 1000 | coords=props["coords"], |
1000 | 1001 | ) |
| 1002 | + |
| 1003 | + |
| 1004 | +def _has_nan_slice(da, dim): |
| 1005 | + """Check if DataArray has NaN slices along a dimension.""" |
| 1006 | + other = tuple(dd for dd in da.dims if dd != dim) |
| 1007 | + return bool(da.isnull().all(dim=other).any()) if other else bool(da.isnull().any()) |
| 1008 | + |
| 1009 | + |
| 1010 | +def _validate_crps_input(y_pred, y_obs, log_likelihood, *, sample_dims, obs_dims): |
| 1011 | + """Shape and dimension checks.""" |
| 1012 | + missing_sample = [d for d in sample_dims if d not in y_pred.dims] |
| 1013 | + if missing_sample: |
| 1014 | + raise ValueError(f"y_pred must include sample dimension '{missing_sample[0]}'") |
| 1015 | + |
| 1016 | + missing_obs = [ |
| 1017 | + d |
| 1018 | + for d in obs_dims |
| 1019 | + if (d not in y_pred.dims or d not in y_obs.dims or d not in log_likelihood.dims) |
| 1020 | + ] |
| 1021 | + if missing_obs: |
| 1022 | + raise ValueError(f"Missing observation dimension '{missing_obs[0]}' in inputs") |
| 1023 | + |
| 1024 | + ypred_obs_aligned, yobs_aligned = xr.align(y_pred, y_obs, join="inner") |
| 1025 | + obs_size_mismatch = [ |
| 1026 | + d |
| 1027 | + for d in obs_dims |
| 1028 | + if ( |
| 1029 | + ypred_obs_aligned.sizes[d] != y_pred.sizes[d] or yobs_aligned.sizes[d] != y_obs.sizes[d] |
| 1030 | + ) |
| 1031 | + ] |
| 1032 | + if obs_size_mismatch: |
| 1033 | + raise ValueError(f"Size mismatch in observation dim '{obs_size_mismatch[0]}'") |
| 1034 | + |
| 1035 | + nan_padded_obs = [d for d in obs_dims if _has_nan_slice(y_obs, d)] |
| 1036 | + if nan_padded_obs: |
| 1037 | + raise ValueError(f"Size mismatch in observation dim '{nan_padded_obs[0]}'") |
| 1038 | + |
| 1039 | + ypred_ll_aligned, ll_aligned = xr.align(y_pred, log_likelihood, join="inner") |
| 1040 | + dims_to_check = (*sample_dims, *obs_dims) |
| 1041 | + ll_size_mismatch = [ |
| 1042 | + d |
| 1043 | + for d in dims_to_check |
| 1044 | + if ( |
| 1045 | + ypred_ll_aligned.sizes[d] != y_pred.sizes[d] |
| 1046 | + or ll_aligned.sizes[d] != log_likelihood.sizes[d] |
| 1047 | + ) |
| 1048 | + ] |
| 1049 | + if ll_size_mismatch: |
| 1050 | + d0 = ll_size_mismatch[0] |
| 1051 | + if d0 in sample_dims: |
| 1052 | + raise ValueError(f"Size mismatch in sample dimension '{d0}'") |
| 1053 | + raise ValueError(f"Size mismatch in observation dim '{d0}'") |
| 1054 | + |
| 1055 | + nan_padded_sample = [d for d in sample_dims if _has_nan_slice(log_likelihood, d)] |
| 1056 | + if nan_padded_sample: |
| 1057 | + raise ValueError(f"Size mismatch in sample dimension '{nan_padded_sample[0]}'") |
0 commit comments