Skip to content

Commit 477c388

Browse files
author
Michael Abbott
committed
+ JuliennedArrays
1 parent b9dd221 commit 477c388

File tree

5 files changed

+65
-17
lines changed

5 files changed

+65
-17
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.1.0"
55

66
[deps]
77
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8+
JuliennedArrays = "5cadff95-7770-533d-a838-a1bf817ee6e0"
89
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
910
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1011
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

README.md

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ maprows(f, M) ≈ mapreduce(f, vcat, eachrow(M))
1414
slicemap(f, A; dims) mapslices(f, A, dims)
1515
```
1616

17-
### An example
17+
### Simple example
1818

1919
```julia
2020
mat = rand(1:9, 3,10)
@@ -56,27 +56,44 @@ mat1k = rand(3,1000);
5656
@btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 245.550 μs
5757
```
5858

59-
It also provides Zygote gradients for the slice/glue functions in
59+
### Other packages
60+
61+
This package also provides Zygote gradients for the slice/glue functions in
6062
[TensorCast](https://github.com/mcabbott/TensorCast.jl),
6163
which can be used to write many mapslices-like operations.
6264
(The function `slicemap(f, A, dims)` uses these functions, without having to write index notation.)
6365

6466
```julia
6567
using TensorCast
66-
@cast [i,j] := fun(mat[:,j])[i] # same as mapcols
68+
@cast [i,j] := fun(mat[:,j])[i] # same as mapcols
6769

6870
tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i]
6971
Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]
7072

71-
@btime tcm($mat1k) # 407.176 μs
72-
@btime Zygote.gradient(m -> sum(sin, tcm(m)), $mat1k) # 19.086 ms
73+
@btime tcm($mat1k) # 407.176 μs
74+
@btime Zygote.gradient(m -> sum(sin, tcm(m)), $mat1k); # 19.086 ms
75+
```
76+
77+
Similar gradients work for the Slice/Align functions in
78+
[JuliennedArrays](https://github.com/bramtayl/JuliennedArrays.jl),
79+
so it defines these too:
80+
81+
```julia
82+
using JuliennedArrays
83+
jumap(f,m) = Align(map(f, Slices(m, True(), False())), True(), False())
84+
jumap(fun, mat) # same as mapcols
85+
Zygote.gradient(m -> sum(sin, jumap(fun, m)), mat)[1]
86+
87+
@btime jumap(fun, $mat1k); # 408.259 μs
88+
@btime Zygote.gradient(m -> sum(sin, jumap(fun, m)), $mat1k); # 18.638 ms
7389
```
7490

7591
### Elsewhere
7692

77-
Issues about mapslices:
93+
About mapslices:
7894
* https://github.com/FluxML/Zygote.jl/issues/92
7995
* https://github.com/FluxML/Flux.jl/issues/741
96+
* https://github.com/JuliaLang/julia/issues/29146
8097

8198
Other packages which define gradients of possible interest:
8299
* https://github.com/GiggleLiu/LinalgBackwards.jl

src/SliceMap.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ module SliceMap
33

44
export mapcols, MapCols, maprows, slicemap
55

6-
using MacroTools, Requires, WeightedArrays, TensorCast, Tracker
6+
using MacroTools, Requires, WeightedArrays, TensorCast, JuliennedArrays
77

8+
using Tracker
89
using Tracker: TrackedMatrix, track, @grad, data
910

1011
#========== Reverse, Eachslice ==========#

src/zygote.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@ using .Zygote: @adjoint, _zero, forward
1414

1515
#===== TensorCast =====#
1616

17+
# @adjoint function TensorCast.sliceview(A::AbstractArray, code::Tuple)
18+
# TensorCast.sliceview(A, code), Δ -> begin
19+
# dA = _zero(A)
20+
# foreach(copyto!, TensorCast.sliceview(dA, code), Δ)
21+
# (dA, nothing)
22+
# end
23+
# end
24+
1725
@adjoint function TensorCast.sliceview(A::AbstractArray, code::Tuple)
18-
TensorCast.sliceview(A, code), Δ -> begin
19-
dA = _zero(A)
20-
foreach(copyto!, TensorCast.sliceview(dA, code), Δ)
21-
(dA, nothing)
22-
end
26+
TensorCast.sliceview(A, code), Δ -> (TensorCast.glue(Δ, code), nothing)
2327
end
2428

2529
@adjoint function TensorCast.red_glue(A::AbstractArray, code::Tuple)
@@ -30,6 +34,16 @@ end
3034
TensorCast.copy_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
3135
end
3236

37+
#===== JuliennedArrays =====#
38+
39+
@adjoint function Slices(whole, along...)
40+
Slices(whole, along...), Δ -> (Align(Δ, along...), map(_->nothing, along)...)
41+
end
42+
43+
@adjoint function Align(whole, along...)
44+
Align(whole, along...), Δ -> (Slices(Δ, along...), map(_->nothing, along)...)
45+
end
46+
3347
#===== Misc Base =====#
3448

3549
@adjoint function Base.reduce(::typeof(hcat), V::AbstractVector{<:AbstractVector})

test/runtests.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
using SliceMap
33
using Test
4-
using ForwardDiff, Tracker, Zygote, TensorCast
4+
using ForwardDiff, Tracker, Zygote, TensorCast, JuliennedArrays
55

66
Zygote.refresh()
77

@@ -26,8 +26,12 @@ Zygote.refresh()
2626
@test res tcm(mat)
2727
@test grad Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]
2828

29+
jcols(f,m) = Align(map(f, Slices(m, True(), False())), True(), False())
30+
@test res jcols(fun, mat)
31+
@test grad Zygote.gradient(m -> sum(sin, jcols(fun, m)), mat)[1]
32+
2933
end
30-
@testset "columns, scalar" begin
34+
@testset "columns -> scalar" begin
3135

3236
mat = rand(1:9, 3,10)
3337
fun(x) = sum(x) # different function!
@@ -49,7 +53,7 @@ end
4953
# @test grad ≈ Zygote.gradient(m -> sum(sin, tcm3(m)), mat)[1]
5054

5155
end
52-
@testset "columns, matrix" begin
56+
@testset "columns -> matrix" begin
5357

5458
mat = rand(1:9, 3,10)
5559
fun(x) = x .* x' # different function! vector -> matrix
@@ -87,16 +91,27 @@ end
8791
# @test res ≈ tcm2(mat)
8892
# @test grad ≈ Zygote.gradient(m -> sum(sin, tcm2(m)), mat)[1]
8993

94+
jrows(f,m) = Align(map(f, Slices(m, False(), True())), False(), True())
95+
@test res jrows(fun, mat)
96+
@test grad Zygote.gradient(m -> sum(sin, jrows(fun, m)), mat)[1]
97+
98+
9099
end
91-
@testset "slices" begin
100+
@testset "slices of a 4-tensor" begin
92101

93102
ten = randn(3,4,5,2)
94-
fun(x) = sqrt(3) .+ x.^3 ./ (sum(x)^2)
103+
fun(x::AbstractVector) = sqrt(3) .+ x.^3 ./ (sum(x)^2)
95104
res = mapslices(fun, ten, dims=3)
96105

97106
@test res slicemap(fun, ten, dims=3)
98107

99108
grad = ForwardDiff.gradient(x -> sum(sin, slicemap(fun, x, dims=3)), ten)
100109
@test grad Zygote.gradient(x -> sum(sin, slicemap(fun, x, dims=3)), ten)[1]
101110

111+
jthree(f,m) = Align(map(f,
112+
Slices(m, False(), False(), True(), False())
113+
), False(), False(), True(), False())
114+
@test res jthree(fun, ten)
115+
@test grad Zygote.gradient(m -> sum(sin, jthree(fun, m)), ten)[1]
116+
102117
end

0 commit comments

Comments
 (0)