Skip to content

Commit 9dc5f38

Browse files
committed
Expand functions in P'(x .^2 .* P)
1 parent 61991d4 commit 9dc5f38

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

src/bases/bases.jl

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,46 @@ function _broadcast_mul_ldiv(::Tuple{ScalarLayout,ApplyLayout{typeof(*)}}, A, B)
145145
a * (A \ b)
146146
end
147147

148-
_broadcast_mul_ldiv(::Tuple{ScalarLayout,AbstractBasisLayout}, A, B) =
149-
_broadcast_mul_ldiv((ScalarLayout(),UnknownLayout()), A, B)
148+
_broadcast_mul_ldiv(::Tuple{ScalarLayout,AbstractBasisLayout}, A, B) = _broadcast_mul_ldiv((ScalarLayout(),UnknownLayout()), A, B)
150149
_broadcast_mul_ldiv(_, A, B) = copy(Ldiv{typeof(MemoryLayout(A)),UnknownLayout}(A,B))
151150

152151
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)}}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
153152
copy(L::Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof(*)}}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
154153

155154

156155

156+
# multiplication operators, reexpand in basis A
157+
@inline function _broadcast_mul_adj(::Tuple{Any,AbstractBasisLayout}, Ac, B)
158+
a,b = arguments(B)
159+
@assert a isa AbstractQuasiVector # Only works for vec .* mat
160+
A = Ac'
161+
ab = (A * (A \ a)) .* b # broadcasted should be overloaded
162+
MemoryLayout(ab) isa BroadcastLayout && return Ac*transform_ldiv(A, ab)
163+
Ac*ab
164+
end
165+
166+
@inline function _broadcast_mul_adj(::Tuple{Any,ApplyLayout{typeof(*)}}, Ac, B)
167+
a,b = arguments(B)
168+
@assert a isa AbstractQuasiVector # Only works for vec .* mat
169+
args = arguments(*, b)
170+
*(Ac*(a .* first(args)), tail(args)...)
171+
end
172+
173+
174+
function _broadcast_mul_adj(::Tuple{ScalarLayout,Any}, Ac, B)
175+
a,b = arguments(B)
176+
a * (Ac*b)
177+
end
178+
179+
function _broadcast_mul_adj(::Tuple{ScalarLayout,ApplyLayout{typeof(*)}}, Ac, B)
180+
a,b = arguments(B)
181+
a * (Ac*b)
182+
end
183+
184+
_broadcast_mul_adj(::Tuple{ScalarLayout,AbstractBasisLayout}, A, B) = _broadcast_mul_adj((ScalarLayout(),UnknownLayout()), A, B)
185+
_broadcast_mul_adj(_, A, B) = copy(Mul{typeof(MemoryLayout(A)),UnknownLayout}(A,B))
186+
187+
copy(L::Mul{<:AdjointBasisLayout,BroadcastLayout{typeof(*)}}) = _broadcast_mul_adj(map(MemoryLayout,arguments(L.B)), L.A, L.B)
157188

158189

159190
"""

0 commit comments

Comments
 (0)