@@ -36,7 +36,9 @@ lastindex(f::BaseField, i::Int) = lastindex(f.arr, i)
3636@propagate_inbounds getindex (f:: BaseField , I:: Union{Int,Colon,AbstractArray} ...) = getindex (f. arr, I... )
3737@propagate_inbounds setindex! (f:: BaseField , X, I:: Union{Int,Colon,AbstractArray} ...) = (setindex! (f. arr, X, I... ); f)
3838similar (f:: BaseField{B} , :: Type{T} ) where {B,T} = BaseField {B} (similar (f. arr, T), f. metadata)
39+ similar (f:: BaseField{B} , :: Type{T} , dims:: Base.DimOrInd... ) where {B,T} = similar (f. arr, T, dims... )
3940copy (f:: BaseField{B} ) where {B} = BaseField {B} (copy (f. arr), f. metadata)
41+ copyto! (dst:: AbstractArray , src:: BaseField ) = copyto! (dst, src. arr)
4042(== )(f₁:: BaseField , f₂:: BaseField ) = strict_compatible_metadata (f₁,f₂) && (f₁. arr == f₂. arr)
4143
4244
@@ -46,7 +48,9 @@ function promote(f₁::BaseField{B₁}, f₂::BaseField{B₂}) where {B₁,B₂}
4648 B = typeof (promote_basis_generic (B₁ (), B₂ ()))
4749 B (f₁), B (f₂)
4850end
49-
51+ # allow very basic arithmetic with BaseField & AbstractArray
52+ promote (f:: BaseField{B} , x:: AbstractArray ) where {B} = (f, BaseField {B} (reshape (x, size (f. arr)), f. proj))
53+ promote (x:: AbstractArray , f:: BaseField{B} ) where {B} = reverse (promote (f, x))
5054
5155# # broadcasting
5256
@@ -61,6 +65,7 @@ BroadcastStyle(::Type{F}) where {B,M,T,A,F<:BaseField{B,M,T,A}} =
6165BroadcastStyle (:: BaseFieldStyle{S₁,B₁} , :: BaseFieldStyle{S₂,B₂} ) where {S₁,B₁,S₂,B₂} =
6266 BaseFieldStyle {typeof(result_style(S₁(), S₂())), typeof(promote_basis_strict(B₁(),B₂()))} ()
6367BroadcastStyle (S:: BaseFieldStyle , :: DefaultArrayStyle{0} ) = S
68+ BaseFieldStyle {S,B} (:: Val{2} ) where {S,B} = DefaultArrayStyle {2} ()
6469
6570# with the Broadcasted object created, we now compute the answer
6671function materialize (bc:: Broadcasted{BaseFieldStyle{S,B}} ) where {S,B}
@@ -101,10 +106,13 @@ function materialize!(dst::BaseField{B}, bc::Broadcasted{BaseFieldStyle{S,B′}}
101106
102107end
103108
104- # the default preprocessing, which just unwraps the underlying array.
105- # this doesn't dispatch on the first argument, but custom BaseFields
106- # are free to override this and dispatch on it if they need
107- preprocess (:: Any , f:: BaseField ) = f. arr
109+ # if broadcasting into a BaseField, the first method here is hit with
110+ # dest::Tuple{BaseFieldStyle,M}, in which case just unwrap the array,
111+ # since it will be fed into a downstream regular broadcast
112+ preprocess (:: Tuple{BaseFieldStyle{S,B},M} , f:: BaseField ) where {S,B,M} = f. arr
113+ # if broadcasting into an Array (ie dropping the BaseField wrapper) we
114+ # need to return the vector representation
115+ preprocess (:: AbstractArray , f:: BaseField ) = view (f. arr, :)
108116
109117# we re-wrap each Broadcasted object as we go through preprocessing
110118# because some array types do special things here (e.g. CUDA wraps
@@ -135,8 +143,7 @@ function strict_compatible_metadata(f₁::BaseField, f₂::BaseField)
135143end
136144
137145# # mapping
138-
139- # this comes up in Zygote.broadcast_forward, and the generic falls back to a regular Array
146+ # map over entries in the array like a true AbstractArray
140147map (func, f:: BaseField{B} ) where {B} = BaseField {B} (map (func, f. arr), f. metadata)
141148
142149
@@ -169,4 +176,4 @@ getproperty(f::BaseField{B}, k::Union{typeof.(Val.((:I,:Q,:U,:E,:B)))...}) where
169176 BaseField {B₀} (_reshape_batch (view (getfield (f,:arr ), pol_slice (f, pol_index (B (), k))... )), getfield (f,:metadata ))
170177getproperty (f:: BaseS02{Basis3Prod{𝐈,B₂,B₀}} , :: Val{:P} ) where {B₂,B₀} =
171178 BaseField {Basis2Prod{B₂,B₀}} (view (getfield (f,:arr ), pol_slice (f, 2 : 3 )... ), getfield (f,:metadata ))
172- getproperty (f:: BaseS2 , :: Val{:P} ) = f
179+ getproperty (f:: BaseS2 , :: Val{:P} ) = f
0 commit comments