|
83 | 83 | @test grad ≈ Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
|
84 | 84 | @test grad ≈ Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
|
85 | 85 |
|
86 |
| - # tcm4(mat) = @cast out[i⊗i′,j] := fun(mat[:,j])[i,i′] i:3, i′:3 # changed here too |
| 86 | + # tcm4(mat) = @cast out[i⊗i′,j] := fun(mat[:,j])[i,i′] i:3 |
87 | 87 | # @test res ≈ tcm4(mat)
|
88 | 88 | # @test grad ≈ Zygote.gradient(m -> sum(sin, tcm4(m)), mat)[1]
|
89 | 89 |
|
| 90 | +end |
| 91 | +@testset "columns w args" begin |
| 92 | + |
| 93 | + mat = randn(Float32, 3,10) |
| 94 | + fun(x, s) = 1 .+ x .* s |
| 95 | + res = mapslices(x -> vec(fun(x,5)), mat, dims=1) |
| 96 | + |
| 97 | + @test res ≈ mapcols(fun, mat, 5) |
| 98 | + @test res ≈ MapCols{3}(fun, mat, 5) |
| 99 | + |
| 100 | + grad = ForwardDiff.gradient(m -> sum(sin, mapslices(x -> vec(fun(x,5)), m, dims=1)), mat) |
| 101 | + |
| 102 | + @test grad ≈ Tracker.gradient(m -> sum(sin, mapcols(fun, m, 5)), mat)[1] |
| 103 | + @test grad ≈ Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m, 5)), mat)[1] |
| 104 | + |
| 105 | + @test grad ≈ Zygote.gradient(m -> sum(sin, mapcols(fun, m, 5)), mat)[1] |
| 106 | + @test grad ≈ Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m, 5)), mat)[1] |
| 107 | + |
| 108 | + # tcm5(mat) = @cast out[i,j] := fun(mat[:,j], 5)[i] |
| 109 | + # @test res ≈ tcm5(mat) |
| 110 | + # @test grad ≈ Zygote.gradient(m -> sum(sin, tcm5(m)), mat)[1] |
| 111 | + |
90 | 112 | end
|
91 | 113 | @testset "rows" begin
|
92 | 114 |
|
|
108 | 130 | @test res ≈ jrows(fun, mat)
|
109 | 131 | @test grad ≈ Zygote.gradient(m -> sum(sin, jrows(fun, m)), mat)[1]
|
110 | 132 |
|
111 |
| - |
112 | 133 | end
|
113 | 134 | @testset "slices of a 4-tensor" begin
|
114 | 135 |
|
|
0 commit comments