|
28 | 28 | "kron", |
29 | 29 | "nunique", |
30 | 30 | "pad", |
| 31 | + "quantile", |
31 | 32 | "setdiff1d", |
32 | 33 | "sinc", |
33 | 34 | ] |
@@ -988,3 +989,166 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
988 | 989 | xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)), |
989 | 990 | ) |
990 | 991 | return xp.sin(y) / y |
| 992 | + |
| 993 | + |
def quantile(
    x: Array,
    q: Array | float,
    /,
    *,
    axis: int | None = None,
    keepdims: bool = False,
    method: str = "linear",
    xp: ModuleType | None = None,
) -> Array:
    """See docstring in `array_api_extra._delegation.py`.

    Implementation notes
    --------------------
    - `x` and `q` are both cast to ``float64`` before any computation.
      NOTE(review): this means ``float32`` input does not keep its dtype --
      confirm this is intended.
    - With ``axis=None`` both `x` and `q` are flattened and the reduction
      runs over the flattened data.
    - The data are sorted along `axis`, that axis is moved to the end, the
      per-method helper is applied, and the axis is moved back afterwards.
    - Entries of the result whose probability is NaN or lies outside
      ``[0, 1]`` are replaced by NaN.
    - A reduced axis of length 1 is squeezed unless `keepdims` is true.

    Raises
    ------
    ValueError
        If `x` is not integral/real floating, `q` is not real floating,
        `axis` is not an in-range integer or None, `method` is unknown, or
        `keepdims` is not None/True/False.
    """
    if xp is None:
        xp = array_namespace(x, q)

    # A Python scalar probability becomes a 0-d float64 array.
    q_is_scalar = isinstance(q, (int, float))
    if q_is_scalar:
        q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x))

    # Validate dtypes before anything is cast.
    if not xp.isdtype(x.dtype, ("integral", "real floating")):
        raise ValueError("`x` must have real dtype.")
    if not xp.isdtype(q.dtype, "real floating"):
        raise ValueError("`q` must have real floating dtype.")

    # Work in float64 throughout, with `q` on the same device as `x`.
    x = xp.astype(x, xp.float64)
    q = xp.astype(q, xp.float64)
    device = _compat.device(x)
    q = xp.asarray(q, device=device)

    dtype = x.dtype
    axis_none = axis is None
    # Recorded before flattening: needed to rebuild the shape for keepdims.
    ndim = max(x.ndim, q.ndim)

    if axis_none:
        x = xp.reshape(x, (-1,))
        q = xp.reshape(q, (-1,))
        axis = 0
    elif not isinstance(axis, int):
        raise ValueError("`axis` must be an integer or None.")
    elif axis >= ndim or axis < -ndim:
        raise ValueError("`axis` is not compatible with the shapes of the inputs.")
    else:
        axis = int(axis)

    # Single source of truth for the supported method names.
    hf_methods = {
        'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
        'hazen', 'interpolated_inverted_cdf', 'linear', 'median_unbiased',
        'normal_unbiased', 'weibull',
    }
    methods = hf_methods | {'harrell-davis'}
    if method not in methods:
        raise ValueError(f"`method` must be one of {methods}")

    if keepdims not in {None, True, False}:
        raise ValueError("If specified, `keepdims` must be True or False.")

    # An empty reduction axis is replaced by a single NaN observation so the
    # index arithmetic below stays valid and the result comes out as NaN.
    if x.shape[axis] == 0:
        shape = list(x.shape)
        shape[axis] = 1
        x = xp.full(shape, xp.nan, dtype=dtype, device=device)

    # Sort along the reduction axis, then put that axis last.
    y = xp.moveaxis(xp.sort(x, axis=axis), axis, -1)

    # Mirror the move on `q` only when `q` actually has that axis.  Without
    # the range check, e.g. 2-D `x` with 1-D `q` worked for axis=-1 but
    # raised for the equivalent axis=1.
    if q.ndim > 0 and -q.ndim <= axis < q.ndim:
        q = xp.moveaxis(q, axis, -1)

    # Number of observations, as a floating array for the position formula.
    n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y))

    # `method` was validated above, so these two branches are exhaustive.
    if method == 'harrell-davis':
        res = _quantile_hd(y, q, n, xp)
    else:
        res = _quantile_hf(y, q, n, method, xp)

    # NaN out results for invalid probabilities.  `where` broadcasts the
    # mask and needs no bool() coercion, so it also works on lazy backends.
    invalid_q = (q > 1) | (q < 0) | xp.isnan(q)
    res = xp.where(invalid_q, xp.full_like(res, xp.nan), res)

    # With axis=None and keepdims, restore the original number of dimensions.
    if axis_none and keepdims:
        shape = (1,) * (ndim - 1) + res.shape
        res = xp.reshape(res, shape)
        axis = -1

    # Put the reduced axis back where it came from.
    res = xp.moveaxis(res, -1, axis)

    if not keepdims and res.shape[axis] == 1:
        res = xp.squeeze(res, axis=axis)

    return res
| 1104 | + |
| 1105 | + |
def _quantile_hf(y: Array, p: Array, n: Array, method: str, xp: ModuleType) -> Array:
    """Evaluate a Hyndman-Fan style quantile rule along the last axis of `y`.

    Parameters
    ----------
    y : Array
        Floating-point data, already sorted along its last axis, which must
        have at least one element.
    p : Array
        Probabilities. Either 0-d, or an array whose last axis enumerates the
        requested probabilities and whose leading axes broadcast against the
        leading axes of `y`.
    n : Array
        Number of observations along the last axis of `y`, as a floating
        array; used in the position formula ``p * n + m - 1``.
    method : str
        One of the keys of the offset table below.
    xp : ModuleType
        Array namespace used for all operations.

    Returns
    -------
    Array
        Shape ``(*y.shape[:-1], k)`` where ``k`` is the length of the last
        axis of `p` (``1`` when `p` is 0-d). Entries whose position is not
        finite (e.g. NaN probability) are NaN.

    Raises
    ------
    KeyError
        If `method` is not in the offset table.
    """
    # Offset `m` of each rule in the fractional position `p * n + m - 1`.
    offsets = {
        'inverted_cdf': 0,
        'averaged_inverted_cdf': 0,
        'closest_observation': -0.5,
        'interpolated_inverted_cdf': 0,
        'hazen': 0.5,
        'weibull': p,
        'linear': 1 - p,
        'median_unbiased': p / 3 + 1 / 3,
        'normal_unbiased': p / 4 + 3 / 8,
    }
    m = offsets[method]

    jg = p * n + m - 1

    # Casting NaN/inf to an integer is undefined, so park those positions at
    # 0 for the index arithmetic and restore NaN in the result at the end.
    bad = xp.logical_not(xp.isfinite(jg))
    jg = xp.where(bad, xp.zeros_like(jg), jg)

    j_floor = jg // 1  # integer part, still floating point
    g = jg % 1         # fractional part = interpolation weight

    # The discontinuous rules replace the fractional weight by a step.
    if method == 'inverted_cdf':
        g = xp.astype(g > 0, jg.dtype)
    elif method == 'averaged_inverted_cdf':
        g = (1 + xp.astype(g > 0, jg.dtype)) / 2
    elif method == 'closest_observation':
        g = 1 - xp.astype((g == 0) & (j_floor % 2 == 1), jg.dtype)

    # Positions before the first observation select it outright.
    g = xp.where(j_floor < 0, xp.zeros_like(g), g)

    # Clamp while still floating (so huge positions cannot overflow the
    # cast), then clamp the upper neighbour with *integer* bounds: mixing an
    # int64 array with a floating bound in `clip` has an unspecified result
    # dtype and yields float indices on plain NumPy.
    last = y.shape[-1] - 1
    j = xp.astype(xp.clip(j_floor, 0.0, float(last)), xp.int64)
    jp1 = xp.clip(j + 1, 0, last)

    # `take_along_axis` needs indices with the same rank as `y`: broadcast
    # them over y's leading axes while keeping one column per probability.
    k = j.shape[-1] if j.ndim > 0 else 1
    idx_shape = (*y.shape[:-1], k)
    j = xp.broadcast_to(j, idx_shape)
    jp1 = xp.broadcast_to(jp1, idx_shape)
    g = xp.broadcast_to(g, idx_shape)
    bad = xp.broadcast_to(bad, idx_shape)

    lower = xp.take_along_axis(y, j, axis=-1)
    upper = xp.take_along_axis(y, jp1, axis=-1)
    res = (1 - g) * lower + g * upper

    return xp.where(bad, xp.full_like(res, xp.nan), res)
| 1148 | + |
| 1149 | + |
def _quantile_hd(y: Array, p: Array, n: Array, xp: ModuleType) -> Array:
    """Stand-in for the Harrell-Davis quantile estimator.

    This is a simplified version: it does not compute Harrell-Davis weights
    and instead returns exactly what `_quantile_hf` produces for the
    ``"linear"`` method, because ``betainc`` is not available in the array
    API standard.

    Parameters
    ----------
    y, p, n, xp
        Forwarded unchanged to `_quantile_hf`.

    Returns
    -------
    Array
        The result of the linear-interpolation rule.
    """
    fallback_method = "linear"
    return _quantile_hf(y, p, n, fallback_method, xp)
0 commit comments