|
46 | 46 |
|
47 | 47 | @testset "real input" begin |
48 | 48 | # even though our rule was define in terms of Wirtinger, |
49 | | - # pushforward result will be real as real (even if seed is Compex) |
| 49 | + # pushforward result will be real as real (even if seed is Complex) |
50 | 50 |
|
51 | | - x = rand(Float64) |
| 51 | + x = 5.0 |
52 | 52 | f, myabs2_pushforward = frule(myabs2, x) |
53 | 53 | @test f === x^2 |
54 | 54 |
|
55 | 55 | Δ = One() |
56 | 56 | df = @inferred myabs2_pushforward(NamedTuple(), Δ) |
57 | 57 | @test df === x + x |
58 | 58 |
|
59 | | - Δ = rand(Complex{Int64}) |
| 59 | + Δ = 2.0 + 3.0im |
60 | 60 | df = @inferred myabs2_pushforward(NamedTuple(), Δ) |
61 | | - @test df === Δ * (x + x) |
| 61 | + @test df === (Δ + conj(Δ)) * x |
62 | 62 | end |
63 | 63 |
|
64 | 64 | @testset "complex input" begin |
65 | | - z = rand(Complex{Float64}) |
| 65 | + z = 5.0 + 7.0im |
66 | 66 | f, myabs2_pushforward = frule(myabs2, z) |
67 | 67 | @test f === abs2(z) |
68 | 68 |
|
69 | 69 | df = @inferred myabs2_pushforward(NamedTuple(), One()) |
70 | 70 | @test df === Wirtinger(z', z) |
71 | 71 |
|
72 | | - Δ = rand(Complex{Int64}) |
| 72 | + Δ = 2.0 + 3.0im |
73 | 73 | df = @inferred myabs2_pushforward(NamedTuple(), Δ) |
74 | | - @test df === Wirtinger(Δ * z', Δ * z) |
| 74 | + @test df === Wirtinger(Δ * conj(z), conj(Δ) * z) |
75 | 75 | end |
76 | 76 | end |
77 | 77 |
|
@@ -134,11 +134,11 @@ end |
134 | 134 | fx, f_pushforward = res |
135 | 135 | df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp) |
136 | 136 |
|
137 | | - df_dx::Thunk = df(One(), Zero()) |
138 | | - df_dp::Thunk = df(Zero(), One()) |
| 137 | + df_dx = df(One(), Zero()) |
| 138 | + df_dp = df(Zero(), One()) |
139 | 139 | @test fx == f(x, p) # Check we still get the normal value, right |
140 | | - @test df_dx() isa expected_type_df_dx |
141 | | - @test df_dp() isa expected_type_df_dp |
| 140 | + @test df_dx isa expected_type_df_dx |
| 141 | + @test df_dp isa expected_type_df_dp |
142 | 142 |
|
143 | 143 |
|
144 | 144 | res = rrule(f, x, p) |
|
147 | 147 | dself, df_dx, df_dp = f_pullback(One()) |
148 | 148 | @test fx == f(x, p) # Check we still get the normal value, right |
149 | 149 | @test dself == NO_FIELDS |
150 | | - @test df_dx() isa expected_type_df_dx |
151 | | - @test df_dp() isa expected_type_df_dp |
| 150 | + @test df_dx isa expected_type_df_dx |
| 151 | + @test df_dp isa expected_type_df_dp |
152 | 152 | end |
153 | 153 | end |
0 commit comments