diff --git a/src/arviz_stats/numba/array.py b/src/arviz_stats/numba/array.py index c59719b9..5be067d7 100644 --- a/src/arviz_stats/numba/array.py +++ b/src/arviz_stats/numba/array.py @@ -66,14 +66,15 @@ def quantile(self, ary, quantile, axis=-1, method="linear", skipna=False, weight axes = axis if axes is not None: ary, axes = process_ary_axes(ary, axes) - axes = [(-1,), (0,), (0,)] else: ary = ary.ravel() + axes = [(-1,), (0,), (0,)] + scalar_q = np.ndim(quantile) == 0 # pylint: disable=no-value-for-parameter, unexpected-keyword-arg - result = _quantile_ufunc(ary, quantile, axes=axes) - if np.ndim(quantile) == 0: - return result + result = _quantile_ufunc(ary, np.atleast_1d(quantile), axes=axes) + if scalar_q: + return result.squeeze(0) return np.moveaxis(result, 0, -1) def _histogram(self, ary, bins=None, range=None, weights=None, density=None): # pylint: disable=redefined-builtin