|
14 | 14 |
|
15 | 15 | loss(x, y) = Flux.mse(m(x), y)
|
16 | 16 | data = [(𝐱, rand(Float32, 128, 1024, 5))]
|
17 |
| - Flux.train!(loss, params(m), data, Flux.ADAM()) |
| 17 | + Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
18 | 18 | end
|
19 | 19 |
|
20 | 20 | @testset "permuted 1D OperatorConv" begin
|
|
34 | 34 |
|
35 | 35 | loss(x, y) = Flux.mse(m(x), y)
|
36 | 36 | data = [(𝐱, rand(Float32, 1024, 128, 5))]
|
37 |
| - Flux.train!(loss, params(m), data, Flux.ADAM()) |
| 37 | + Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
38 | 38 | end
|
39 | 39 |
|
40 | 40 | @testset "1D OperatorKernel" begin
|
|
52 | 52 |
|
53 | 53 | loss(x, y) = Flux.mse(m(x), y)
|
54 | 54 | data = [(𝐱, rand(Float32, 128, 1024, 5))]
|
55 |
| - Flux.train!(loss, params(m), data, Flux.ADAM()) |
| 55 | + Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
56 | 56 | end
|
57 | 57 |
|
58 | 58 | @testset "permuted 1D OperatorKernel" begin
|
|
71 | 71 |
|
72 | 72 | loss(x, y) = Flux.mse(m(x), y)
|
73 | 73 | data = [(𝐱, rand(Float32, 1024, 128, 5))]
|
74 |
| - Flux.train!(loss, params(m), data, Flux.ADAM()) |
| 74 | + Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
75 | 75 | end
|
76 | 76 |
|
77 | 77 | @testset "2D OperatorConv" begin
|
|
89 | 89 |
|
90 | 90 | loss(x, y) = Flux.mse(m(x), y)
|
91 | 91 | data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
|
92 |
| - Flux.train!(loss, params(m), data, Flux.ADAM()) |
| 92 | + Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
93 | 93 | end
|
94 | 94 |
|
95 | 95 | @testset "permuted 2D OperatorConv" begin
|
|
108 | 108 |
|
109 | 109 | loss(x, y) = Flux.mse(m(x), y)
|
110 | 110 | data = [(𝐱, rand(Float32, 22, 22, 64, 5))]
|
111 |
| - Flux.train!(loss, params(m), data, Flux.ADAM()) |
| 111 | + Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
112 | 112 | end
|
113 | 113 |
|
114 | 114 | @testset "2D OperatorKernel" begin
|
|
125 | 125 |
|
126 | 126 | loss(x, y) = Flux.mse(m(x), y)
|
127 | 127 | data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
|
128 |
| - Flux.train!(loss, params(m), data, Flux.ADAM()) |
| 128 | + Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
129 | 129 | end
|
130 | 130 |
|
131 | 131 | @testset "permuted 2D OperatorKernel" begin
|
|
143 | 143 |
|
144 | 144 | loss(x, y) = Flux.mse(m(x), y)
|
145 | 145 | data = [(𝐱, rand(Float32, 22, 22, 64, 5))]
|
146 |
| - Flux.train!(loss, params(m), data, Flux.ADAM()) |
| 146 | + Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
147 | 147 | end
|
148 | 148 |
|
149 | 149 | @testset "SpectralConv" begin
|
|
165 | 165 | graph = grid([10, 10])
|
166 | 166 | 𝐱 = rand(Float32, channel, N, batch_size)
|
167 | 167 | l = WithGraph(FeaturedGraph(graph), GraphKernel(κ, channel))
|
168 |
| - @test repr(l.layer) == "GraphKernel(Dense(64, 32, relu), channel=32)" |
| 168 | + @test repr(l.layer) == "GraphKernel(Dense(64 => 32, relu), channel=32)" |
169 | 169 | @test size(l(𝐱)) == (channel, N, batch_size)
|
170 | 170 |
|
171 | 171 | g = Zygote.gradient(() -> sum(l(𝐱)), Flux.params(l))
|
|
0 commit comments