|
178 | 178 |
|
179 | 179 | @test_throws DimensionMismatch push!(stack([rand(2)]), rand(3))
|
180 | 180 |
|
181 |
| -end |
182 |
| -@info "loading Zygote" |
183 |
| -using Zygote |
184 |
| -@testset "zygote" begin |
185 |
| - |
186 |
| - @test Zygote.gradient((x,y) -> sum(stack(x,y)), ones(2), ones(2)) == ([1,1], [1,1]) |
187 |
| - @test Zygote.gradient((x,y) -> sum(stack([x,y])), ones(2), ones(2)) == ([1,1], [1,1]) |
188 |
| - |
189 |
| - f399(x) = sum(stack(x) * sum(x)) |
190 |
| - f399c(x) = sum(collect(stack(x)) * sum(x)) |
191 |
| - @test Zygote.gradient(f399, [ones(2), ones(2)]) == ([[4,4], [4,4]],) |
192 |
| - @test Zygote.gradient(f399c, [ones(2), ones(2)]) == ([[4,4], [4,4]],) |
193 |
| - ftup(x) = sum(stack(x...) * sum(x)) |
194 |
| - ftupc(x) = sum(collect(stack(x...)) * sum(x)) |
195 |
| - @test Zygote.gradient(ftup, (ones(2), ones(2))) == (([4,4], [4,4]),) |
196 |
| - @test Zygote.gradient(ftupc, (ones(2), ones(2))) == (([4,4], [4,4]),) |
197 |
| - |
198 | 181 | end
|
199 | 182 | @testset "readme" begin
|
200 | 183 |
|
|
223 | 206 | @test rstack([1,2], 1:3) == [1 1; 2 2; 0 3]
|
224 | 207 | @test rstack([[1,2], 1:3], fill=99) == [1 1; 2 2; 99 3]
|
225 | 208 |
|
226 |
| - @test rstack(1:2, OffsetArray([2,3], 2:3)) == [1 0; 2 2; 0 3] |
227 |
| - @test rstack(1:2, OffsetArray([0.1,1], 0:1)) == OffsetArray([0 0.1; 1 1.0; 2 0],-1,0) |
| 209 | + @test rstack(1:2, OffsetArray([2,3], +1)) == [1 0; 2 2; 0 3] |
| 210 | + @test rstack(1:2, OffsetArray([0.1,1], -1)) == OffsetArray([0 0.1; 1 1.0; 2 0],-1,0) |
| 211 | + |
| 212 | + @test dimnames(rstack(:b, NamedDimsArray(1:2, :a), OffsetArray([2,3], +1))) == (:a, :b) |
| 213 | + |
| 214 | +end |
| 215 | +@info "loading Zygote" |
| 216 | +using Zygote |
| 217 | +@testset "zygote" begin |
| 218 | + |
| 219 | + @test Zygote.gradient((x,y) -> sum(stack(x,y)), ones(2), ones(2)) == ([1,1], [1,1]) |
| 220 | + @test Zygote.gradient((x,y) -> sum(stack([x,y])), ones(2), ones(2)) == ([1,1], [1,1]) |
| 221 | + |
| 222 | + f399(x) = sum(stack(x) * sum(x)) |
| 223 | + f399c(x) = sum(collect(stack(x)) * sum(x)) |
| 224 | + @test Zygote.gradient(f399, [ones(2), ones(2)]) == ([[4,4], [4,4]],) |
| 225 | + @test Zygote.gradient(f399c, [ones(2), ones(2)]) == ([[4,4], [4,4]],) |
| 226 | + ftup(x) = sum(stack(x...) * sum(x)) |
| 227 | + ftupc(x) = sum(collect(stack(x...)) * sum(x)) |
| 228 | + @test Zygote.gradient(ftup, (ones(2), ones(2))) == (([4,4], [4,4]),) |
| 229 | + @test Zygote.gradient(ftupc, (ones(2), ones(2))) == (([4,4], [4,4]),) |
228 | 230 |
|
229 | 231 | end
|
0 commit comments