@@ -1056,122 +1056,6 @@ def key_function(out_key):
1056
1056
1057
1057
1058
1058
def reduction (
1059
- x : "Array" ,
1060
- func ,
1061
- combine_func = None ,
1062
- aggregate_func = None ,
1063
- axis = None ,
1064
- intermediate_dtype = None ,
1065
- dtype = None ,
1066
- keepdims = False ,
1067
- use_new_impl = True ,
1068
- split_every = None ,
1069
- extra_func_kwargs = None ,
1070
- ) -> "Array" :
1071
- """Apply a function to reduce an array along one or more axes."""
1072
- if use_new_impl :
1073
- return reduction_new (
1074
- x ,
1075
- func ,
1076
- combine_func = combine_func ,
1077
- aggregate_func = aggregate_func ,
1078
- axis = axis ,
1079
- intermediate_dtype = intermediate_dtype ,
1080
- dtype = dtype ,
1081
- keepdims = keepdims ,
1082
- split_every = split_every ,
1083
- extra_func_kwargs = extra_func_kwargs ,
1084
- )
1085
- if combine_func is None :
1086
- combine_func = func
1087
- if axis is None :
1088
- axis = tuple (range (x .ndim ))
1089
- if isinstance (axis , Integral ):
1090
- axis = (axis ,)
1091
- axis = validate_axis (axis , x .ndim )
1092
- if intermediate_dtype is None :
1093
- intermediate_dtype = dtype
1094
-
1095
- inds = tuple (range (x .ndim ))
1096
-
1097
- result = x
1098
- allowed_mem = x .spec .allowed_mem
1099
- max_mem = allowed_mem - x .spec .reserved_mem
1100
-
1101
- # reduce initial chunks
1102
- args = (result , inds )
1103
- adjust_chunks = {
1104
- i : (1 ,) * len (c ) if i in axis else c for i , c in enumerate (result .chunks )
1105
- }
1106
- result = blockwise (
1107
- func ,
1108
- inds ,
1109
- * args ,
1110
- axis = axis ,
1111
- keepdims = True ,
1112
- dtype = intermediate_dtype ,
1113
- adjust_chunks = adjust_chunks ,
1114
- extra_func_kwargs = extra_func_kwargs ,
1115
- )
1116
-
1117
- # merge/reduce along axis in multiple rounds until there's a single block in each reduction axis
1118
- while any (n > 1 for i , n in enumerate (result .numblocks ) if i in axis ):
1119
- # merge along axis
1120
- target_chunks = list (result .chunksize )
1121
- chunk_mem = array_memory (intermediate_dtype , result .chunksize )
1122
- for i , s in enumerate (result .shape ):
1123
- if i in axis :
1124
- assert result .chunksize [i ] == 1 # result of reduction
1125
- if len (axis ) > 1 :
1126
- # multi-axis: don't exceed original chunksize in any reduction axis
1127
- # TODO: improve to use up to max_mem
1128
- target_chunks [i ] = min (s , x .chunksize [i ])
1129
- else :
1130
- # single axis: see how many result chunks fit in max_mem
1131
- # factor of 4 is memory for {compressed, uncompressed} x {input, output}
1132
- target_chunk_size = (max_mem - chunk_mem ) // (chunk_mem * 4 )
1133
- if target_chunk_size <= 1 :
1134
- raise ValueError (
1135
- f"Not enough memory for reduction. Increase allowed_mem ({ allowed_mem } ) or decrease chunk size"
1136
- )
1137
- target_chunks [i ] = min (s , target_chunk_size )
1138
- _target_chunks = tuple (target_chunks )
1139
- result = merge_chunks (result , _target_chunks )
1140
-
1141
- # reduce chunks (if any axis chunksize is > 1)
1142
- if any (s > 1 for i , s in enumerate (result .chunksize ) if i in axis ):
1143
- args = (result , inds )
1144
- adjust_chunks = {
1145
- i : (1 ,) * len (c ) if i in axis else c
1146
- for i , c in enumerate (result .chunks )
1147
- }
1148
- result = blockwise (
1149
- combine_func ,
1150
- inds ,
1151
- * args ,
1152
- axis = axis ,
1153
- keepdims = True ,
1154
- dtype = intermediate_dtype ,
1155
- adjust_chunks = adjust_chunks ,
1156
- extra_func_kwargs = extra_func_kwargs ,
1157
- )
1158
-
1159
- if aggregate_func is not None :
1160
- result = map_blocks (aggregate_func , result , dtype = dtype )
1161
-
1162
- if not keepdims :
1163
- axis_to_squeeze = tuple (i for i in axis if result .shape [i ] == 1 )
1164
- if len (axis_to_squeeze ) > 0 :
1165
- result = squeeze (result , axis_to_squeeze )
1166
-
1167
- from cubed .array_api import astype
1168
-
1169
- result = astype (result , dtype , copy = False )
1170
-
1171
- return result
1172
-
1173
-
1174
- def reduction_new (
1175
1059
x : "Array" ,
1176
1060
func ,
1177
1061
combine_func = None ,
@@ -1426,9 +1310,7 @@ def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None):
1426
1310
return result
1427
1311
1428
1312
1429
- def arg_reduction (
1430
- x , / , arg_func , axis = None , * , keepdims = False , use_new_impl = True , split_every = None
1431
- ):
1313
+ def arg_reduction (x , / , arg_func , axis = None , * , keepdims = False , split_every = None ):
1432
1314
"""A reduction that returns the array indexes, not the values."""
1433
1315
dtype = nxp .int64 # index data type
1434
1316
intermediate_dtype = [("i" , dtype ), ("v" , x .dtype )]
@@ -1454,7 +1336,6 @@ def arg_reduction(
1454
1336
intermediate_dtype = intermediate_dtype ,
1455
1337
dtype = dtype ,
1456
1338
keepdims = keepdims ,
1457
- use_new_impl = use_new_impl ,
1458
1339
split_every = split_every ,
1459
1340
)
1460
1341
0 commit comments