|
28 | 28 | "kron", |
29 | 29 | "nunique", |
30 | 30 | "pad", |
31 | | - "quantile", |
32 | 31 | "setdiff1d", |
33 | 32 | "sinc", |
34 | 33 | ] |
@@ -989,151 +988,3 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
989 | 988 | xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)), |
990 | 989 | ) |
991 | 990 | return xp.sin(y) / y |
992 | | - |
993 | | - |
994 | | -def quantile( |
995 | | - x: Array, |
996 | | - q: Array | float, |
997 | | - /, |
998 | | - *, |
999 | | - axis: int | None = None, |
1000 | | - keepdims: bool = False, |
1001 | | - method: str = "linear", |
1002 | | - xp: ModuleType | None = None, |
1003 | | -) -> Array: # numpydoc ignore=PR01,RT01 |
1004 | | - """See docstring in `array_api_extra._delegation.py`.""" |
1005 | | - if xp is None: |
1006 | | - xp = array_namespace(x, q) |
1007 | | - |
1008 | | - q_is_scalar = isinstance(q, int | float) |
1009 | | - if q_is_scalar: |
1010 | | - q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x)) |
1011 | | - |
1012 | | - if not xp.isdtype(x.dtype, ("integral", "real floating")): |
1013 | | - raise ValueError("`x` must have real dtype.") # noqa: EM101 |
1014 | | - if not xp.isdtype(q.dtype, "real floating"): |
1015 | | - raise ValueError("`q` must have real floating dtype.") # noqa: EM101 |
1016 | | - |
1017 | | - # Promote to common dtype |
1018 | | - x = xp.astype(x, xp.float64) |
1019 | | - q = xp.astype(q, xp.float64) |
1020 | | - q = xp.asarray(q, device=_compat.device(x)) |
1021 | | - |
1022 | | - dtype = x.dtype |
1023 | | - axis_none = axis is None |
1024 | | - ndim = max(x.ndim, q.ndim) |
1025 | | - |
1026 | | - if axis_none: |
1027 | | - x = xp.reshape(x, (-1,)) |
1028 | | - q = xp.reshape(q, (-1,)) |
1029 | | - axis = 0 |
1030 | | - elif not isinstance(axis, int): |
1031 | | - raise ValueError("`axis` must be an integer or None.") # noqa: EM101 |
1032 | | - elif axis >= ndim or axis < -ndim: |
1033 | | - raise ValueError("`axis` is not compatible with the shapes of the inputs.") # noqa: EM101 |
1034 | | - else: |
1035 | | - axis = int(axis) |
1036 | | - |
1037 | | - if keepdims not in {None, True, False}: |
1038 | | - raise ValueError("If specified, `keepdims` must be True or False.") # noqa: EM101 |
1039 | | - |
1040 | | - if x.shape[axis] == 0: |
1041 | | - shape = list(x.shape) |
1042 | | - shape[axis] = 1 |
1043 | | - x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x)) |
1044 | | - |
1045 | | - y = xp.sort(x, axis=axis) |
1046 | | - |
1047 | | - # Move axis to the end for easier processing |
1048 | | - y = xp.moveaxis(y, axis, -1) |
1049 | | - if not (q_is_scalar or q.ndim == 0): |
1050 | | - q = xp.moveaxis(q, axis, -1) |
1051 | | - |
1052 | | - n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) |
1053 | | - |
1054 | | - if method in { # pylint: disable=duplicate-code |
1055 | | - "inverted_cdf", |
1056 | | - "averaged_inverted_cdf", |
1057 | | - "closest_observation", |
1058 | | - "hazen", |
1059 | | - "interpolated_inverted_cdf", |
1060 | | - "linear", |
1061 | | - "median_unbiased", |
1062 | | - "normal_unbiased", |
1063 | | - "weibull", |
1064 | | - }: |
1065 | | - res = _quantile_hf(y, q, n, method, xp) |
1066 | | - else: |
1067 | | - raise ValueError(f"Unknown method: {method}") # noqa: EM102 |
1068 | | - |
1069 | | - # Handle NaN output for invalid q values |
1070 | | - p_mask = (q > 1) | (q < 0) | xp.isnan(q) |
1071 | | - if xp.any(p_mask): |
1072 | | - res = xp.asarray(res, copy=True) |
1073 | | - res = at(res, p_mask).set(xp.nan) |
1074 | | - |
1075 | | - # Reshape per axis/keepdims |
1076 | | - if axis_none and keepdims: |
1077 | | - shape = (1,) * (ndim - 1) + res.shape |
1078 | | - res = xp.reshape(res, shape) |
1079 | | - axis = -1 |
1080 | | - |
1081 | | - # Move axis back to original position |
1082 | | - res = xp.moveaxis(res, -1, axis) |
1083 | | - |
1084 | | - # Handle keepdims |
1085 | | - if not keepdims and res.shape[axis] == 1: |
1086 | | - res = xp.squeeze(res, axis=axis) |
1087 | | - |
1088 | | - # For scalar q, ensure we return a scalar result |
1089 | | - if q_is_scalar and hasattr(res, "shape") and res.shape != (): |
1090 | | - res = res[()] |
1091 | | - |
1092 | | - return res |
1093 | | - |
1094 | | - |
1095 | | -def _quantile_hf( |
1096 | | - y: Array, p: Array, n: Array, method: str, xp: ModuleType |
1097 | | -) -> Array: # numpydoc ignore=PR01,RT01 |
1098 | | - """Helper function for Hyndman-Fan quantile method.""" |
1099 | | - ms = { |
1100 | | - "inverted_cdf": 0, |
1101 | | - "averaged_inverted_cdf": 0, |
1102 | | - "closest_observation": -0.5, |
1103 | | - "interpolated_inverted_cdf": 0, |
1104 | | - "hazen": 0.5, |
1105 | | - "weibull": p, |
1106 | | - "linear": 1 - p, |
1107 | | - "median_unbiased": p / 3 + 1 / 3, |
1108 | | - "normal_unbiased": p / 4 + 3 / 8, |
1109 | | - } |
1110 | | - m = ms[method] |
1111 | | - |
1112 | | - jg = p * n + m - 1 |
1113 | | - j = xp.astype(jg // 1, xp.int64) # Convert to integer |
1114 | | - g = jg % 1 |
1115 | | - |
1116 | | - if method == "inverted_cdf": |
1117 | | - g = xp.astype((g > 0), jg.dtype) |
1118 | | - elif method == "averaged_inverted_cdf": |
1119 | | - g = (1 + xp.astype((g > 0), jg.dtype)) / 2 |
1120 | | - elif method == "closest_observation": |
1121 | | - g = 1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype) |
1122 | | - if method in {"inverted_cdf", "averaged_inverted_cdf", "closest_observation"}: |
1123 | | - g = xp.asarray(g) |
1124 | | - g = at(g, jg < 0).set(0) |
1125 | | - g = at(g, j < 0).set(0) |
1126 | | - j = xp.clip(j, 0, n - 1) |
1127 | | - jp1 = xp.clip(j + 1, 0, n - 1) |
1128 | | - |
1129 | | - # Broadcast indices to match y shape except for the last axis |
1130 | | - if y.ndim > 1: |
1131 | | - # Create broadcast shape for indices |
1132 | | - broadcast_shape = list(y.shape[:-1]) + [1] # noqa: RUF005 |
1133 | | - j = xp.broadcast_to(j, broadcast_shape) |
1134 | | - jp1 = xp.broadcast_to(jp1, broadcast_shape) |
1135 | | - g = xp.broadcast_to(g, broadcast_shape) |
1136 | | - |
1137 | | - return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis( |
1138 | | - y, jp1, axis=-1 |
1139 | | - ) |
0 commit comments