Skip to content

Commit 8729a2f

Browse files
committed
Opinionated cleanup
1 parent 1154ad3 commit 8729a2f

File tree

5 files changed

+58
-62
lines changed

5 files changed

+58
-62
lines changed

src/arithmetics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ function count(
283283
)
284284
mapreduce(
285285
x -> x ? one(typeof(init)) : zero(typeof(init)), +, src, backend;
286-
init=init,
286+
init,
287287
neutral=zero(typeof(init)),
288288
kwargs...
289289
)

src/reduce/mapreduce_1d.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function mapreduce_1d(
117117
len == 1 && return @allowscalar f(src[1])
118118
if len < switch_below
119119
h_src = Vector(src)
120-
return Base.mapreduce(f, op, h_src, init=init)
120+
return Base.mapreduce(f, op, h_src; init)
121121
end
122122

123123
# Each thread will handle two elements
@@ -147,7 +147,7 @@ function mapreduce_1d(
147147
len = blocks
148148
if len < switch_below
149149
h_src = Vector(@view(dst[1:len]))
150-
return Base.reduce(op, h_src, init=init)
150+
return Base.reduce(op, h_src; init)
151151
end
152152

153153
# Now all src elements have been passed through f; just do final reduction, no map needed
@@ -163,7 +163,7 @@ function mapreduce_1d(
163163

164164
if len < switch_below
165165
h_src = Vector(@view(p2[1:len]))
166-
return Base.reduce(op, h_src, init=init)
166+
return Base.reduce(op, h_src; init)
167167
end
168168

169169
p1, p2 = p2, p1

src/reduce/reduce.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -212,28 +212,24 @@ function _mapreduce_impl(
212212
if isnothing(dims)
213213
return mapreduce_1d(
214214
f, op, src, backend;
215-
init=init,
216-
neutral=neutral,
217-
block_size=block_size,
218-
temp=temp,
219-
switch_below=switch_below,
215+
init, neutral,
216+
block_size, temp,
217+
switch_below,
220218
)
221219
else
222220
return mapreduce_nd(
223221
f, op, src, backend;
224-
init=init,
225-
neutral=neutral,
226-
dims=dims,
227-
block_size=block_size,
228-
temp=temp,
222+
init, neutral,
223+
dims, block_size,
224+
temp,
229225
)
230226
end
231227
else
232228
if isnothing(dims)
233229
num_elems = length(src)
234230
num_tasks = min(max_tasks, num_elems ÷ min_elems)
235231
if num_tasks <= 1
236-
return Base.mapreduce(f, op, src; init=init)
232+
return Base.mapreduce(f, op, src; init)
237233
end
238234
return op(init, OMT.tmapreduce(
239235
f, op, src; init=neutral,
@@ -243,7 +239,7 @@ function _mapreduce_impl(
243239
))
244240
else
245241
# FIXME: waiting on OhMyThreads.jl for n-dimensional reduction
246-
return Base.mapreduce(f, op, src; init=init, dims=dims)
242+
return Base.mapreduce(f, op, src; init, dims)
247243
end
248244
end
249245
end

test/accumulate.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
y = similar(x)
6666
init = rand(-1000:1000)
6767
AK.accumulate!(+, y, x; init=Int32(init))
68-
@test all(Array(y) .== accumulate(+, Array(x), init=init))
68+
@test all(Array(y) .== accumulate(+, Array(x); init))
6969
end
7070

7171
# Exclusive scan
@@ -103,10 +103,10 @@ end
103103
for ksize in 0:3
104104
sh = rand(Int32(1):Int32(100), isize, jsize, ksize)
105105
s = array_from_host(sh)
106-
d = AK.accumulate(+, s; init=Int32(0), dims=dims)
106+
d = AK.accumulate(+, s; init=Int32(0), dims)
107107

108108
dh = Array(d)
109-
dhres = accumulate(+, sh, init=Int32(0), dims=dims)
109+
dhres = accumulate(+, sh; init=Int32(0), dims)
110110
@test dh == dhres
111111
@test eltype(dh) == eltype(dhres)
112112
end
@@ -123,9 +123,9 @@ end
123123
vh = rand(Int32(1):Int32(100), n1, n2, n3)
124124
v = array_from_host(vh)
125125

126-
s = AK.accumulate(+, v; init=Int32(0), dims=dims)
126+
s = AK.accumulate(+, v; init=Int32(0), dims)
127127
sh = Array(s)
128-
@test sh == accumulate(+, vh, init=Int32(0), dims=dims)
128+
@test sh == accumulate(+, vh; init=Int32(0), dims)
129129
end
130130
end
131131

@@ -137,9 +137,9 @@ end
137137
vh = rand(UInt32(1):UInt32(100), n1, n2, n3)
138138
v = array_from_host(vh)
139139

140-
s = AK.accumulate(+, v; init=UInt32(0), dims=dims)
140+
s = AK.accumulate(+, v; init=UInt32(0), dims)
141141
sh = Array(s)
142-
@test sh == accumulate(+, vh, init=UInt32(0), dims=dims)
142+
@test sh == accumulate(+, vh; init=UInt32(0), dims)
143143
end
144144
end
145145

@@ -151,9 +151,9 @@ end
151151
vh = rand(Float32, n1, n2, n3)
152152
v = array_from_host(vh)
153153

154-
s = AK.accumulate(+, v; init=Float32(0), dims=dims)
154+
s = AK.accumulate(+, v; init=Float32(0), dims)
155155
sh = Array(s)
156-
@test all(sh .≈ accumulate(+, vh, init=Float32(0), dims=dims))
156+
@test all(sh .≈ accumulate(+, vh; init=Float32(0), dims))
157157
end
158158
end
159159

@@ -166,9 +166,9 @@ end
166166
vh = rand(Float32, n1, n2, n3)
167167
v = array_from_host(vh)
168168
init = rand(-1000:1000)
169-
s = AK.accumulate(+, v; init=Float32(init), dims=dims)
169+
s = AK.accumulate(+, v; init=Float32(init), dims)
170170
sh = Array(s)
171-
@test all(sh .≈ accumulate(+, vh, init=Float32(init), dims=dims))
171+
@test all(sh .≈ accumulate(+, vh; init=Float32(init), dims))
172172
end
173173
end
174174

@@ -235,8 +235,8 @@ end
235235
# @test all(Array(AK.cumsum(v)) .== cumsum(vh))
236236

237237
# Along dimensions
238-
r = Array(AK.cumsum(v, dims=dims))
239-
rh = cumsum(vh, dims=dims)
238+
r = Array(AK.cumsum(v; dims))
239+
rh = cumsum(vh; dims)
240240

241241
@test r == rh
242242
end
@@ -281,8 +281,8 @@ end
281281
# @test all(Array(AK.cumprod(v)) .== cumprod(vh))
282282

283283
# Along dimensions
284-
r = Array(AK.cumprod(v, dims=dims))
285-
rh = cumprod(vh, dims=dims)
284+
r = Array(AK.cumprod(v; dims))
285+
rh = cumprod(vh; dims)
286286

287287
@test r == rh
288288
end

test/reduce.jl

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,10 @@ end
153153
for ksize in 0:3
154154
sh = rand(Int32(1):Int32(100), isize, jsize, ksize)
155155
s = array_from_host(sh)
156-
d = AK.reduce(+, s; init=Int32(10), dims=dims)
156+
d = AK.reduce(+, s; init=Int32(10), dims)
157157
dh = Array(d)
158-
@test dh == sum(sh, init=Int32(10), dims=dims)
159-
@test eltype(dh) == eltype(sum(sh, init=Int32(10), dims=dims))
158+
@test dh == sum(sh; init=Int32(10), dims)
159+
@test eltype(dh) == eltype(sum(sh; init=Int32(10), dims))
160160
end
161161
end
162162
end
@@ -170,9 +170,9 @@ end
170170
n3 = rand(1:100)
171171
vh = rand(Int32(1):Int32(100), n1, n2, n3)
172172
v = array_from_host(vh)
173-
s = AK.reduce(+, v; init=Int32(0), dims=dims)
173+
s = AK.reduce(+, v; init=Int32(0), dims)
174174
sh = Array(s)
175-
@test sh == sum(vh, dims=dims)
175+
@test sh == sum(vh; dims)
176176
end
177177
end
178178

@@ -183,9 +183,9 @@ end
183183
n3 = rand(1:100)
184184
vh = rand(UInt32(1):UInt32(100), n1, n2, n3)
185185
v = array_from_host(vh)
186-
s = AK.reduce(+, v; init=UInt32(0), dims=dims)
186+
s = AK.reduce(+, v; init=UInt32(0), dims)
187187
sh = Array(s)
188-
@test sh == sum(vh, dims=dims)
188+
@test sh == sum(vh; dims)
189189
end
190190
end
191191

@@ -196,9 +196,9 @@ end
196196
n3 = rand(1:100)
197197
vh = rand(Float32, n1, n2, n3)
198198
v = array_from_host(vh)
199-
s = AK.reduce(+, v; init=Float32(0), dims=dims)
199+
s = AK.reduce(+, v; init=Float32(0), dims)
200200
sh = Array(s)
201-
@test sh sum(vh, dims=dims)
201+
@test sh sum(vh; dims)
202202
end
203203
end
204204

@@ -211,9 +211,9 @@ end
211211
vh = rand(Int32(1):Int32(100), n1, n2, n3)
212212
v = array_from_host(vh)
213213
init = rand(1:100)
214-
s = AK.reduce(+, v; init=Int32(init), dims=dims)
214+
s = AK.reduce(+, v; init=Int32(init), dims)
215215
sh = Array(s)
216-
@test sh == reduce(+, vh, dims=dims, init=init)
216+
@test sh == reduce(+, vh; dims, init)
217217
end
218218
end
219219

@@ -321,7 +321,7 @@ end
321321
init = rand(1:100)
322322
s = AK.mapreduce(abs, +, v; switch_below=switch_below, init=Int32(init))
323323
vh = Array(v)
324-
@test s == mapreduce(abs, +, vh, init=init)
324+
@test s == mapreduce(abs, +, vh; init)
325325
end
326326

327327
# Test with unmaterialised ranges
@@ -363,10 +363,10 @@ end
363363
for ksize in 0:3
364364
sh = rand(Int32(-100):Int32(100), isize, jsize, ksize)
365365
s = array_from_host(sh)
366-
d = AK.mapreduce(-, +, s; init=Int32(-10), dims=dims)
366+
d = AK.mapreduce(-, +, s; init=Int32(-10), dims)
367367
dh = Array(d)
368-
@test dh == mapreduce(-, +, sh, init=Int32(-10), dims=dims)
369-
@test eltype(dh) == eltype(mapreduce(-, +, sh, init=Int32(-10), dims=dims))
368+
@test dh == mapreduce(-, +, sh; init=Int32(-10), dims)
369+
@test eltype(dh) == eltype(mapreduce(-, +, sh; init=Int32(-10), dims))
370370
end
371371
end
372372
end
@@ -380,9 +380,9 @@ end
380380
n3 = rand(1:100)
381381
vh = rand(Int32(1):Int32(100), n1, n2, n3)
382382
v = array_from_host(vh)
383-
s = AK.mapreduce(-, +, v; init=Int32(0), dims=dims)
383+
s = AK.mapreduce(-, +, v; init=Int32(0), dims)
384384
sh = Array(s)
385-
@test sh == mapreduce(-, +, vh, init=Int32(0), dims=dims)
385+
@test sh == mapreduce(-, +, vh; init=Int32(0), dims)
386386
end
387387
end
388388

@@ -394,7 +394,7 @@ end
394394
s;
395395
init=(typemax(Float32), typemax(Float32)),
396396
neutral=(typemax(Float32), typemax(Float32)),
397-
dims=dims,
397+
dims,
398398
)
399399
end
400400

@@ -405,7 +405,7 @@ end
405405
(a, b) -> (a[1] < b[1] ? a[1] : b[1], a[2] < b[2] ? a[2] : b[2]),
406406
s;
407407
init=(typemax(Float32), typemax(Float32)),
408-
dims=dims,
408+
dims,
409409
)
410410
end
411411

@@ -439,9 +439,9 @@ end
439439
vh = rand(Int32(-100):Int32(100), n1, n2, n3)
440440
v = array_from_host(vh)
441441
init = rand(1:100)
442-
s = AK.mapreduce(-, +, v; init=Int32(init), dims=dims)
442+
s = AK.mapreduce(-, +, v; init=Int32(init), dims)
443443
sh = Array(s)
444-
@test sh == mapreduce(-, +, vh, dims=dims, init=init)
444+
@test sh == mapreduce(-, +, vh; dims, init)
445445
end
446446
end
447447

@@ -502,8 +502,8 @@ end
502502
@test AK.sum(v) == sum(vh)
503503

504504
# Along dimensions
505-
r = Array(AK.sum(v, dims=dims))
506-
rh = sum(vh, dims=dims)
505+
r = Array(AK.sum(v; dims))
506+
rh = sum(vh; dims)
507507

508508
@test r == rh
509509
end
@@ -544,8 +544,8 @@ end
544544
@test AK.sum(v) == sum(vh)
545545

546546
# Along dimensions
547-
r = Array(AK.sum(v, dims=dims))
548-
rh = sum(vh, dims=dims)
547+
r = Array(AK.sum(v; dims))
548+
rh = sum(vh; dims)
549549

550550
@test r == rh
551551
end
@@ -586,8 +586,8 @@ end
586586
@test AK.minimum(v) == minimum(vh)
587587

588588
# Along dimensions
589-
r = Array(AK.minimum(v, dims=dims))
590-
rh = minimum(vh, dims=dims)
589+
r = Array(AK.minimum(v; dims))
590+
rh = minimum(vh; dims)
591591

592592
@test r == rh
593593
end
@@ -628,8 +628,8 @@ end
628628
@test AK.maximum(v) == maximum(vh)
629629

630630
# Along dimensions
631-
r = Array(AK.maximum(v, dims=dims))
632-
rh = maximum(vh, dims=dims)
631+
r = Array(AK.maximum(v; dims))
632+
rh = maximum(vh; dims)
633633

634634
@test r == rh
635635
end
@@ -670,8 +670,8 @@ end
670670
@test AK.count(x->x>0.5, v) == count(x->x>0.5, vh)
671671

672672
# Along dimensions
673-
r = Array(AK.count(x->x>0.5, v, dims=dims))
674-
rh = count(x->x>0.5, vh, dims=dims)
673+
r = Array(AK.count(x->x>0.5, v; dims))
674+
rh = count(x->x>0.5, vh; dims)
675675

676676
@test r == rh
677677
end

0 commit comments

Comments
 (0)