Skip to content

Commit ede7319

Browse files
author
Michael Abbott
committed
args bug + tests
1 parent d328e5e commit ede7319

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

src/SliceMap.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ surevec(A) = vec(A) # to allow f vector -> matrix, by reshaping
3737
_mapcols(map::Function, f::Function, M::TrackedMatrix, args...) = track(_mapcols, map, f, M, args...)
3838

3939
@grad _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
40-
∇mapcols(map, map(col -> Tracker.forward(x -> surevec(f(x, args...)), col), eachcol(data(M))), args)
40+
∇mapcols(map, map(col -> Tracker.forward(x -> surevec(f(x, args...)), col), eachcol(data(M))), args...)
4141

42-
function ∇mapcols(bigmap, forwards, args)
42+
function ∇mapcols(bigmap, forwards, args...)
4343
reduce(hcat, map(datafirst, forwards)), Δ -> begin
4444
cols = bigmap((fwd, Δcol) -> data(last(fwd)(Δcol)[1]), forwards, eachcol(data(Δ)))
4545
(nothing, nothing, reduce(hcat, cols), map(_->nothing, args)...)
@@ -220,12 +220,15 @@ end
220220
# What KissThreading does is much more complicated, perhaps worth investigating:
221221
# https://github.com/mohamed82008/KissThreading.jl/blob/master/src/KissThreading.jl
222222

223+
# BTW I do the first one because some diffeq maps are infer to ::Any
224+
# else you could use Core.Compiler.return_type(f, Tuple{eltype(x)})
225+
223226
"""
224227
threadmap(f, A)
225228
threadmap(f, A, B)
226229
227230
Simple version of `map` using a `Threads.@threads` loop;
228-
only for vectors & only two of them, of nonzero length,
231+
only for vectors & really at most two of them, of nonzero length,
229232
with all outputs having the same type.
230233
"""
231234
function threadmap(f::Function, vw::AbstractVector...)

test/runtests.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,32 @@ end
8383
@test grad Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
8484
@test grad Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
8585

86-
# tcm4(mat) = @cast out[i⊗i′,j] := fun(mat[:,j])[i,i′] i:3, i′:3 # changed here too
86+
# tcm4(mat) = @cast out[i⊗i′,j] := fun(mat[:,j])[i,i′] i:3
8787
# @test res ≈ tcm4(mat)
8888
# @test grad ≈ Zygote.gradient(m -> sum(sin, tcm4(m)), mat)[1]
8989

90+
end
91+
@testset "columns w args" begin
92+
93+
mat = randn(Float32, 3,10)
94+
fun(x, s) = 1 .+ x .* s
95+
res = mapslices(x -> vec(fun(x,5)), mat, dims=1)
96+
97+
@test res mapcols(fun, mat, 5)
98+
@test res MapCols{3}(fun, mat, 5)
99+
100+
grad = ForwardDiff.gradient(m -> sum(sin, mapslices(x -> vec(fun(x,5)), m, dims=1)), mat)
101+
102+
@test grad Tracker.gradient(m -> sum(sin, mapcols(fun, m, 5)), mat)[1]
103+
@test grad Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m, 5)), mat)[1]
104+
105+
@test grad Zygote.gradient(m -> sum(sin, mapcols(fun, m, 5)), mat)[1]
106+
@test grad Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m, 5)), mat)[1]
107+
108+
# tcm5(mat) = @cast out[i,j] := fun(mat[:,j], 5)[i]
109+
# @test res ≈ tcm5(mat)
110+
# @test grad ≈ Zygote.gradient(m -> sum(sin, tcm5(m)), mat)[1]
111+
90112
end
91113
@testset "rows" begin
92114

@@ -108,7 +130,6 @@ end
108130
@test res jrows(fun, mat)
109131
@test grad Zygote.gradient(m -> sum(sin, jrows(fun, m)), mat)[1]
110132

111-
112133
end
113134
@testset "slices of a 4-tensor" begin
114135

0 commit comments

Comments
 (0)