11using Base. Broadcast: Broadcast as BC
2- using FillArrays: Zeros, fillsimilar
3- using TensorAlgebra: TensorAlgebra, * ₗ, + ₗ, - ₗ, / ₗ, conjed
2+ using TensorAlgebra: TensorAlgebra
43
# Broadcast style for sector-tagged arrays, parametrized by the sector label
# type `I` and the number of dimensions `N`.
struct SectorStyle{I, N} <: BC.AbstractArrayStyle{N} end

# Rebuild the style at a different dimensionality `M`. This constructor is
# part of the `AbstractArrayStyle` interface: broadcasting machinery calls it
# when the result's ndims differs from an argument's.
SectorStyle{I, N}(::Val{M}) where {I, N, M} = SectorStyle{I, M}()
# Materialize a lazy `SectorArray`: materialize the wrapped `data` and re-tag
# the result with this array's sector via `ofsector`.
function Base.Broadcast.materialize(a::SectorArray)
    return ofsector(a, Base.Broadcast.materialize(a.data))
end
# Allocate the output array for a broadcast whose combined style is
# `SectorStyle`.
#
# The first `SectorArray` among the (flattened) broadcast arguments supplies
# the sector tag for the result; the output storage is allocated as a
# `similar` copy of that argument's underlying `data` with element type `elt`.
#
# NOTE(review): `ax` is accepted per the `similar(bc, elt, ax)` convention but
# is not forwarded to `similar(arg.data, elt)` — presumably every argument of
# a `SectorStyle` broadcast shares the same axes; confirm this is intentional.
function Base.similar(bc::BC.Broadcasted{<:SectorStyle}, elt::Type, ax)
    flat = BC.flatten(bc)
    # Assumes at least one argument is a `SectorArray` (which is what makes
    # the combined style `SectorStyle`); `findfirst` would return `nothing`
    # otherwise — TODO(review) confirm.
    idx = findfirst(x -> x isa SectorArray, flat.args)
    sector_arg = flat.args[idx]
    return ofsector(sector_arg, similar(sector_arg.data, elt))
end
4638
# Evaluate a `SectorStyle` broadcast into `dest`.
#
# Sector-tagged arrays only support *linear* broadcast expressions.
# `TensorAlgebra.tryflattenlinear` rewrites the broadcast tree into a linear
# form, returning `nothing` when the expression is not linear; in that case
# an `ArgumentError` is raised rather than silently producing a wrong result.
function Base.copyto!(dest::SectorArray, bc::BC.Broadcasted{<:SectorStyle})
    linear = TensorAlgebra.tryflattenlinear(bc)
    if isnothing(linear)
        throw(ArgumentError("SectorArray broadcasting requires linear operations"))
    end
    return copyto!(dest, linear)
end
5045
5146struct GradedStyle{I, N, B <: BC.AbstractArrayStyle{N} } <: BC.AbstractArrayStyle{N}
@@ -90,67 +85,6 @@ function Base.similar(bc::BC.Broadcasted{<:GradedStyle}, elt::Type, ax)
9085 return graded_similar (arg, elt, ax)
9186end
9287
93- function _check_add_axes (a:: AbstractArray , b:: AbstractArray )
94- axes (a) == axes (b) ||
95- throw (
96- ArgumentError (" linear broadcasting requires matching axes" )
97- )
98- return nothing
99- end
100-
101- function lazyblock (a:: GradedArray{<:Any, N} , I:: Vararg{Block{1}, N} ) where {N}
102- if isstored (a, I... )
103- return blocks (a)[Int .(I)... ]
104- else
105- block_ax = map ((ax, i) -> eachblockaxis (ax)[Int (i)], axes (a), I)
106- return fillsimilar (Zeros {eltype(a)} (block_ax), block_ax)
107- end
108- end
109- lazyblock (a:: GradedArray , I:: Block ) = lazyblock (a, Tuple (I)... )
110-
111- TensorAlgebra. @scaledarray_type ScaledGradedArray
112- TensorAlgebra. @scaledarray ScaledGradedArray
113- TensorAlgebra. @conjarray_type ConjGradedArray
114- TensorAlgebra. @conjarray ConjGradedArray
115- TensorAlgebra. @addarray_type AddGradedArray
116- TensorAlgebra. @addarray AddGradedArray
117-
118- const LazyGradedArray = Union{
119- GradedArray, ScaledGradedArray, ConjGradedArray, AddGradedArray,
120- }
121-
122- function TensorAlgebra. BroadcastStyle_scaled (arrayt:: Type{<:ScaledGradedArray} )
123- return BC. BroadcastStyle (TensorAlgebra. unscaled_type (arrayt))
124- end
125- function TensorAlgebra. BroadcastStyle_conj (arrayt:: Type{<:ConjGradedArray} )
126- return BC. BroadcastStyle (TensorAlgebra. conjed_type (arrayt))
127- end
128- function TensorAlgebra. BroadcastStyle_add (arrayt:: Type{<:AddGradedArray} )
129- args_type = TensorAlgebra. addends_type (arrayt)
130- return Base. promote_op (BC. combine_styles, fieldtypes (args_type)... )()
131- end
132-
133- function lazyblock (a:: ScaledGradedArray , I:: Block )
134- return TensorAlgebra. coeff (a) * ₗ lazyblock (TensorAlgebra. unscaled (a), I)
135- end
136- function lazyblock (a:: ConjGradedArray , I:: Block )
137- return conjed (lazyblock (conjed (a), I))
138- end
139- function lazyblock (a:: AddGradedArray , I:: Block )
140- return + ₗ (map (Base. Fix2 (lazyblock, I), TensorAlgebra. addends (a))... )
141- end
142-
143- # TODO : Use `eachblockstoredindex` directly for lazy graded wrappers and delete the
144- # `graded_eachblockstoredindex` helper once that refactor is split into its own PR.
145- graded_eachblockstoredindex (a:: GradedArray ) = collect (eachblockstoredindex (a))
146- function graded_eachblockstoredindex (a:: ScaledGradedArray )
147- return graded_eachblockstoredindex (TensorAlgebra. unscaled (a))
148- end
149- graded_eachblockstoredindex (a:: ConjGradedArray ) = graded_eachblockstoredindex (conjed (a))
150- function graded_eachblockstoredindex (a:: AddGradedArray )
151- return unique! (vcat (map (graded_eachblockstoredindex, TensorAlgebra. addends (a))... ))
152- end
153-
15488# TODO : Rename `graded_similar` to `similar_graded` or fold it into `similar`
15589# entirely once the follow-up allocator cleanup is ready.
15690function graded_similar (
@@ -160,52 +94,10 @@ function graded_similar(
16094 ) where {N}
16195 return similar (a, elt, ax)
16296end
163- function graded_similar (
164- a:: ScaledGradedArray ,
165- elt:: Type ,
166- ax:: NTuple{N, <:GradedUnitRange}
167- ) where {N}
168- return graded_similar (TensorAlgebra. unscaled (a), elt, ax)
169- end
170- function graded_similar (
171- a:: ConjGradedArray ,
172- elt:: Type ,
173- ax:: NTuple{N, <:GradedUnitRange}
174- ) where {N}
175- return graded_similar (conjed (a), elt, ax)
176- end
177- function graded_similar (
178- a:: AddGradedArray ,
179- elt:: Type ,
180- ax:: NTuple{N, <:GradedUnitRange}
181- ) where {N}
182- style = BC. combine_styles (TensorAlgebra. addends (a)... )
183- bc = BC. Broadcasted (style, + , TensorAlgebra. addends (a))
184- return similar (bc, elt, ax)
185- end
186-
187- function copy_lazygraded (a:: LazyGradedArray )
188- c = graded_similar (a, eltype (a), axes (a))
189- for I in graded_eachblockstoredindex (a)
190- c[I] = lazyblock (a, I)
191- end
192- return c
193- end
194-
195- function TensorAlgebra.:+ ₗ (a:: LazyGradedArray , b:: LazyGradedArray )
196- _check_add_axes (a, b)
197- return AddGradedArray (a, b)
198- end
199- TensorAlgebra.:* ₗ (α:: Number , a:: GradedArray ) = ScaledGradedArray (α, a)
200- TensorAlgebra. conjed (a:: GradedArray ) = ConjGradedArray (a)
201-
202- Base. copy (a:: ScaledGradedArray ) = copy_lazygraded (a)
203- Base. copy (a:: ConjGradedArray ) = copy_lazygraded (a)
204- Base. copy (a:: AddGradedArray ) = copy_lazygraded (a)
205- Base. Array (a:: ScaledGradedArray ) = Array (copy (a))
206- Base. Array (a:: ConjGradedArray ) = Array (copy (a))
207- Base. Array (a:: AddGradedArray ) = Array (copy (a))
20897
# Evaluate a `GradedStyle` broadcast into `dest`.
#
# Graded arrays only support *linear* broadcast expressions.
# `TensorAlgebra.tryflattenlinear` rewrites the broadcast tree into a linear
# form, returning `nothing` when the expression is not linear; in that case
# an `ArgumentError` is raised rather than silently producing a wrong result.
function Base.copyto!(dest::GradedArray, bc::BC.Broadcasted{<:GradedStyle})
    linear = TensorAlgebra.tryflattenlinear(bc)
    if isnothing(linear)
        throw(ArgumentError("GradedArray broadcasting requires linear operations"))
    end
    return copyto!(dest, linear)
end
0 commit comments