1
1
2
2
using SliceMap
3
3
using Test
4
- using ForwardDiff, Tracker, Zygote, TensorCast
4
+ using ForwardDiff, Tracker, Zygote, TensorCast, JuliennedArrays
5
5
6
6
Zygote. refresh ()
7
7
@@ -26,8 +26,12 @@ Zygote.refresh()
26
26
@test res ≈ tcm (mat)
27
27
@test grad ≈ Zygote. gradient (m -> sum (sin, tcm (m)), mat)[1 ]
28
28
29
+ jcols (f,m) = Align (map (f, Slices (m, True (), False ())), True (), False ())
30
+ @test res ≈ jcols (fun, mat)
31
+ @test grad ≈ Zygote. gradient (m -> sum (sin, jcols (fun, m)), mat)[1 ]
32
+
29
33
end
30
- @testset " columns, scalar" begin
34
+ @testset " columns -> scalar" begin
31
35
32
36
mat = rand (1 : 9 , 3 ,10 )
33
37
fun (x) = sum (x) # different function!
49
53
# @test grad ≈ Zygote.gradient(m -> sum(sin, tcm3(m)), mat)[1]
50
54
51
55
end
52
- @testset " columns, matrix" begin
56
+ @testset " columns -> matrix" begin
53
57
54
58
mat = rand (1 : 9 , 3 ,10 )
55
59
fun (x) = x .* x' # different function! vector -> matrix
87
91
# @test res ≈ tcm2(mat)
88
92
# @test grad ≈ Zygote.gradient(m -> sum(sin, tcm2(m)), mat)[1]
89
93
94
+ jrows (f,m) = Align (map (f, Slices (m, False (), True ())), False (), True ())
95
+ @test res ≈ jrows (fun, mat)
96
+ @test grad ≈ Zygote. gradient (m -> sum (sin, jrows (fun, m)), mat)[1 ]
97
+
98
+
90
99
end
91
- @testset " slices" begin
100
+ @testset " slices of a 4-tensor " begin
92
101
93
102
ten = randn (3 ,4 ,5 ,2 )
94
- fun (x) = sqrt (3 ) .+ x.^ 3 ./ (sum (x)^ 2 )
103
+ fun (x:: AbstractVector ) = sqrt (3 ) .+ x.^ 3 ./ (sum (x)^ 2 )
95
104
res = mapslices (fun, ten, dims= 3 )
96
105
97
106
@test res ≈ slicemap (fun, ten, dims= 3 )
98
107
99
108
grad = ForwardDiff. gradient (x -> sum (sin, slicemap (fun, x, dims= 3 )), ten)
100
109
@test grad ≈ Zygote. gradient (x -> sum (sin, slicemap (fun, x, dims= 3 )), ten)[1 ]
101
110
111
+ jthree (f,m) = Align (map (f,
112
+ Slices (m, False (), False (), True (), False ())
113
+ ), False (), False (), True (), False ())
114
+ @test res ≈ jthree (fun, ten)
115
+ @test grad ≈ Zygote. gradient (m -> sum (sin, jthree (fun, m)), ten)[1 ]
116
+
102
117
end
0 commit comments