|
1 | 1 |
|
2 | 2 | using SliceMap
|
3 | 3 | using Test
|
4 |
| -using ForwardDiff, Tracker, Zygote, TensorCast, JuliennedArrays |
| 4 | +using ForwardDiff, Tracker, Zygote, JuliennedArrays |
5 | 5 |
|
6 | 6 | Zygote.refresh()
|
7 | 7 |
|
@@ -35,10 +35,6 @@ Zygote.refresh()
|
35 | 35 | @test grad ≈ Zygote.gradient(m -> sum(sin, tmapcols(fun, m)), mat)[1]
|
36 | 36 | @test grad ≈ Zygote.gradient(m -> sum(sin, ThreadMapCols{3}(fun, m)), mat)[1]
|
37 | 37 |
|
38 |
| - tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i] |
39 |
| - @test res ≈ tcm(mat) |
40 |
| - @test grad ≈ Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1] |
41 |
| - |
42 | 38 | jcols(f,m) = Align(map(f, Slices(m, True(), False())), True(), False())
|
43 | 39 | @test res ≈ jcols(fun, mat)
|
44 | 40 | @test grad ≈ Zygote.gradient(m -> sum(sin, jcols(fun, m)), mat)[1]
|
|
61 | 57 | @test grad ≈ Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
|
62 | 58 | @test grad ≈ Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
|
63 | 59 |
|
64 |
| - # tcm3(mat) = @cast out[_,j] := fun(mat[:,j]) # changed here too |
65 |
| - # @test res ≈ tcm3(mat) |
66 |
| - # @test grad ≈ Zygote.gradient(m -> sum(sin, tcm3(m)), mat)[1] |
67 |
| - |
68 | 60 | end
|
69 | 61 | @testset "columns -> matrix" begin
|
70 | 62 |
|
|
83 | 75 | @test grad ≈ Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
|
84 | 76 | @test grad ≈ Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
|
85 | 77 |
|
86 |
| - # tcm4(mat) = @cast out[i⊗i′,j] := fun(mat[:,j])[i,i′] i:3 |
87 |
| - # @test res ≈ tcm4(mat) |
88 |
| - # @test grad ≈ Zygote.gradient(m -> sum(sin, tcm4(m)), mat)[1] |
89 |
| - |
90 | 78 | end
|
91 | 79 | @testset "columns w args" begin
|
92 | 80 |
|
|
105 | 93 | @test grad ≈ Zygote.gradient(m -> sum(sin, mapcols(fun, m, 5)), mat)[1]
|
106 | 94 | @test grad ≈ Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m, 5)), mat)[1]
|
107 | 95 |
|
108 |
| - # tcm5(mat) = @cast out[i,j] := fun(mat[:,j], 5)[i] |
109 |
| - # @test res ≈ tcm5(mat) |
110 |
| - # @test grad ≈ Zygote.gradient(m -> sum(sin, tcm5(m)), mat)[1] |
111 |
| - |
112 | 96 | end
|
113 | 97 | @testset "rows" begin
|
114 | 98 |
|
|
122 | 106 | @test grad ≈ Tracker.gradient(m -> sum(sin, maprows(fun, m)), mat)[1]
|
123 | 107 | @test grad ≈ Zygote.gradient(m -> sum(sin, maprows(fun, m)), mat)[1]
|
124 | 108 |
|
125 |
| - # tcm2(mat) = @cast out[i,j] := fun(mat[i,:])[j] |
126 |
| - # @test res ≈ tcm2(mat) |
127 |
| - # @test grad ≈ Zygote.gradient(m -> sum(sin, tcm2(m)), mat)[1] |
128 |
| - |
129 | 109 | jrows(f,m) = Align(map(f, Slices(m, False(), True())), False(), True())
|
130 | 110 | @test res ≈ jrows(fun, mat)
|
131 | 111 | @test grad ≈ Zygote.gradient(m -> sum(sin, jrows(fun, m)), mat)[1]
|
|
0 commit comments