@@ -1004,147 +1004,162 @@ def quantile(
10041004 """See docstring in `array_api_extra._delegation.py`."""
10051005 if xp is None :
10061006 xp = array_namespace (x , q )
1007-
1007+
10081008 # Convert q to array if it's a scalar
1009- q_is_scalar = isinstance (q , ( int , float ) )
1009+ q_is_scalar = isinstance (q , int | float )
10101010 if q_is_scalar :
10111011 q = xp .asarray (q , dtype = xp .float64 , device = _compat .device (x ))
1012-
1012+
10131013 # Validate inputs
10141014 if not xp .isdtype (x .dtype , ("integral" , "real floating" )):
1015- raise ValueError ("`x` must have real dtype." )
1015+ raise ValueError ("`x` must have real dtype." ) # noqa: EM101
10161016 if not xp .isdtype (q .dtype , "real floating" ):
1017- raise ValueError ("`q` must have real floating dtype." )
1018-
1017+ raise ValueError ("`q` must have real floating dtype." ) # noqa: EM101
1018+
10191019 # Promote to common dtype
10201020 x = xp .astype (x , xp .float64 )
10211021 q = xp .astype (q , xp .float64 )
10221022 q = xp .asarray (q , device = _compat .device (x ))
1023-
1023+
10241024 dtype = x .dtype
10251025 axis_none = axis is None
10261026 ndim = max (x .ndim , q .ndim )
1027-
1027+
10281028 if axis_none :
10291029 x = xp .reshape (x , (- 1 ,))
10301030 q = xp .reshape (q , (- 1 ,))
10311031 axis = 0
10321032 elif not isinstance (axis , int ):
1033- raise ValueError ("`axis` must be an integer or None." )
1033+ raise ValueError ("`axis` must be an integer or None." ) # noqa: EM101
10341034 elif axis >= ndim or axis < - ndim :
1035- raise ValueError ("`axis` is not compatible with the shapes of the inputs." )
1035+ raise ValueError ("`axis` is not compatible with the shapes of the inputs." ) # noqa: EM101
10361036 else :
10371037 axis = int (axis )
1038-
1038+
10391039 # Validate method
10401040 methods = {
1041- 'inverted_cdf' , 'averaged_inverted_cdf' , 'closest_observation' ,
1042- 'hazen' , 'interpolated_inverted_cdf' , 'linear' , 'median_unbiased' ,
1043- 'normal_unbiased' , 'weibull' , 'harrell-davis'
1041+ "inverted_cdf" ,
1042+ "averaged_inverted_cdf" ,
1043+ "closest_observation" ,
1044+ "hazen" ,
1045+ "interpolated_inverted_cdf" ,
1046+ "linear" ,
1047+ "median_unbiased" ,
1048+ "normal_unbiased" ,
1049+ "weibull" ,
1050+ "harrell-davis" ,
10441051 }
10451052 if method not in methods :
1046- raise ValueError (f"`method` must be one of { methods } " )
1047-
1053+ raise ValueError (f"`method` must be one of { methods } " ) # noqa: EM102
1054+
10481055 # Handle keepdims parameter
10491056 if keepdims not in {None , True , False }:
1050- raise ValueError ("If specified, `keepdims` must be True or False." )
1051-
1057+ raise ValueError ("If specified, `keepdims` must be True or False." ) # noqa: EM101
1058+
10521059 # Handle empty arrays
10531060 if x .shape [axis ] == 0 :
10541061 shape = list (x .shape )
10551062 shape [axis ] = 1
10561063 x = xp .full (shape , xp .nan , dtype = dtype , device = _compat .device (x ))
1057-
1064+
10581065 # Sort the data
10591066 y = xp .sort (x , axis = axis )
1060-
1067+
10611068 # Move axis to the end for easier processing
10621069 y = xp .moveaxis (y , axis , - 1 )
10631070 if not (q_is_scalar or q .ndim == 0 ):
10641071 q = xp .moveaxis (q , axis , - 1 )
1065-
1072+
10661073 # Get the number of elements along the axis
10671074 n = xp .asarray (y .shape [- 1 ], dtype = dtype , device = _compat .device (y ))
1068-
1075+
10691076 # Apply quantile calculation based on method
1070- if method in {'inverted_cdf' , 'averaged_inverted_cdf' , 'closest_observation' ,
1071- 'hazen' , 'interpolated_inverted_cdf' , 'linear' , 'median_unbiased' ,
1072- 'normal_unbiased' , 'weibull' }:
1077+ if method in {
1078+ "inverted_cdf" ,
1079+ "averaged_inverted_cdf" ,
1080+ "closest_observation" ,
1081+ "hazen" ,
1082+ "interpolated_inverted_cdf" ,
1083+ "linear" ,
1084+ "median_unbiased" ,
1085+ "normal_unbiased" ,
1086+ "weibull" ,
1087+ }:
10731088 res = _quantile_hf (y , q , n , method , xp )
1074- elif method == ' harrell-davis' :
1089+ elif method == " harrell-davis" :
10751090 res = _quantile_hd (y , q , n , xp )
10761091 else :
1077- raise ValueError (f"Unknown method: { method } " )
1078-
1092+ raise ValueError (f"Unknown method: { method } " ) # noqa: EM102
1093+
10791094 # Handle NaN output for invalid q values
10801095 p_mask = (q > 1 ) | (q < 0 ) | xp .isnan (q )
10811096 if xp .any (p_mask ):
10821097 res = xp .asarray (res , copy = True )
10831098 res = at (res , p_mask ).set (xp .nan )
1084-
1099+
10851100 # Reshape per axis/keepdims
10861101 if axis_none and keepdims :
10871102 shape = (1 ,) * (ndim - 1 ) + res .shape
10881103 res = xp .reshape (res , shape )
10891104 axis = - 1
1090-
1105+
10911106 # Move axis back to original position
10921107 res = xp .moveaxis (res , - 1 , axis )
1093-
1108+
10941109 # Handle keepdims
10951110 if not keepdims and res .shape [axis ] == 1 :
10961111 res = xp .squeeze (res , axis = axis )
1097-
1112+
10981113 # For scalar q, ensure we return a scalar result
1099- if q_is_scalar :
1100- if hasattr (res , 'shape' ) and res .shape != ():
1101- res = res [()]
1102-
1114+ if q_is_scalar and hasattr (res , "shape" ) and res .shape != ():
1115+ res = res [()]
1116+
11031117 return res
11041118
11051119
11061120def _quantile_hf (y : Array , p : Array , n : Array , method : str , xp : ModuleType ) -> Array :
11071121 """Helper function for Hyndman-Fan quantile methods."""
11081122 ms = {
1109- ' inverted_cdf' : 0 ,
1110- ' averaged_inverted_cdf' : 0 ,
1111- ' closest_observation' : - 0.5 ,
1112- ' interpolated_inverted_cdf' : 0 ,
1113- ' hazen' : 0.5 ,
1114- ' weibull' : p ,
1115- ' linear' : 1 - p ,
1116- ' median_unbiased' : p / 3 + 1 / 3 ,
1117- ' normal_unbiased' : p / 4 + 3 / 8
1123+ " inverted_cdf" : 0 ,
1124+ " averaged_inverted_cdf" : 0 ,
1125+ " closest_observation" : - 0.5 ,
1126+ " interpolated_inverted_cdf" : 0 ,
1127+ " hazen" : 0.5 ,
1128+ " weibull" : p ,
1129+ " linear" : 1 - p ,
1130+ " median_unbiased" : p / 3 + 1 / 3 ,
1131+ " normal_unbiased" : p / 4 + 3 / 8 ,
11181132 }
11191133 m = ms [method ]
1120-
1134+
11211135 jg = p * n + m - 1
11221136 j = xp .astype (jg // 1 , xp .int64 ) # Convert to integer
11231137 g = jg % 1
1124-
1125- if method == ' inverted_cdf' :
1138+
1139+ if method == " inverted_cdf" :
11261140 g = xp .astype ((g > 0 ), jg .dtype )
1127- elif method == ' averaged_inverted_cdf' :
1141+ elif method == " averaged_inverted_cdf" :
11281142 g = (1 + xp .astype ((g > 0 ), jg .dtype )) / 2
1129- elif method == ' closest_observation' :
1130- g = ( 1 - xp .astype ((g == 0 ) & (j % 2 == 1 ), jg .dtype ) )
1131- if method in {' inverted_cdf' , ' averaged_inverted_cdf' , ' closest_observation' }:
1143+ elif method == " closest_observation" :
1144+ g = 1 - xp .astype ((g == 0 ) & (j % 2 == 1 ), jg .dtype )
1145+ if method in {" inverted_cdf" , " averaged_inverted_cdf" , " closest_observation" }:
11321146 g = xp .asarray (g )
11331147 g = at (g , jg < 0 ).set (0 )
11341148 g = at (g , j < 0 ).set (0 )
11351149 j = xp .clip (j , 0 , n - 1 )
11361150 jp1 = xp .clip (j + 1 , 0 , n - 1 )
1137-
1151+
11381152 # Broadcast indices to match y shape except for the last axis
11391153 if y .ndim > 1 :
11401154 # Create broadcast shape for indices
1141- broadcast_shape = list (y .shape [:- 1 ]) + [1 ]
1155+ broadcast_shape = list (y .shape [:- 1 ]) + [1 ] # noqa: RUF005
11421156 j = xp .broadcast_to (j , broadcast_shape )
11431157 jp1 = xp .broadcast_to (jp1 , broadcast_shape )
11441158 g = xp .broadcast_to (g , broadcast_shape )
1145-
1146- return ((1 - g ) * xp .take_along_axis (y , j , axis = - 1 ) +
1147- g * xp .take_along_axis (y , jp1 , axis = - 1 ))
1159+
1160+ return (1 - g ) * xp .take_along_axis (y , j , axis = - 1 ) + g * xp .take_along_axis (
1161+ y , jp1 , axis = - 1
1162+ )
11481163
11491164
11501165def _quantile_hd (y : Array , p : Array , n : Array , xp : ModuleType ) -> Array :
0 commit comments