|
57 | 57 | @adjoint function Zygote.literal_getproperty(f::BaseField{B,M,T}, ::Val{:arr}) where {B<:SpatialBasis{AzFourier},M,T} |
58 | 58 | getfield(f,:arr), Δ -> (BaseField{B}(Δ ./ adapt(typeof(Δ), T.(rfft_degeneracy_fac(f.Nx)' ./ Zfac(B(), f.metadata))), f.metadata),) |
59 | 59 | end |
60 | | -# preserve field type for sub-component property getters |
61 | | -function _getproperty_subcomponent_pullback(f, k) |
62 | | - function getproperty_pullback(Δ) |
63 | | - g = similar(f, promote_type(eltype(f), eltype(Δ))) |
64 | | - g .= 0 |
| 60 | +# needed to preserve field type for sub-component property getters |
| 61 | +@adjoint function Zygote.getproperty(f::BaseField, k::Union{typeof.(Val.((:I,:Q,:U,:E,:B,:P,:IP)))...}) |
| 62 | + function field_getproperty_pullback(Δ) |
| 63 | + g = (similar(f, promote_type(eltype(f), eltype(Δ))) .= 0) |
65 | 64 | getproperty(g, k) .= Δ |
66 | 65 | (g, nothing) |
67 | 66 | end |
68 | | - getproperty(f, k), getproperty_pullback |
69 | | -end |
70 | | -@adjoint function Zygote.literal_getproperty(f::BaseField{B}, k::Union{typeof.(Val.((:I,:Q,:U,:E,:B)))...}) where {B₀, B<:SpatialBasis{B₀}} |
71 | | - _getproperty_subcomponent_pullback(f, k) |
72 | | -end |
73 | | -@adjoint function Zygote.literal_getproperty(f::BaseS02{Basis3Prod{𝐈,B₂,B₀}}, k::Val{:P}) where {B₂,B₀} |
74 | | - _getproperty_subcomponent_pullback(f, k) |
| 67 | + getproperty(f, k), field_getproperty_pullback |
75 | 68 | end |
76 | 69 | # if accumulting from one branch that was just a f.metadata |
77 | 70 | Zygote.accum(f::BaseField, nt::NamedTuple{(:arr,:metadata)}) = (@assert(isnothing(nt.arr)); f) |
|
0 commit comments