Skip to content

Commit 041233b

Browse files
Merge pull request #203 from vpuri3/conj
fix conj overload for matrixop, batcheddiagonal
2 parents a1e7622 + 3c89279 commit 041233b

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

src/batch.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
BatchedDiagonalOperator(diag; update_func, update_func!, accepted_kwargs)
44
55
Represents a time-dependent elementwise scaling (diagonal-scaling) operation.
6-
Acts on `AbstractArray`s of the same size as `diag`. The update function is called
7-
by `update_coefficients!` and is assumed to have the following signature:
6+
Acts on `AbstractArray`s of the same size as `diag`. The update function is
7+
called by `update_coefficients!` and is assumed to have the following signature:
88
99
update_func(diag::AbstractArray, u, p, t; <accepted kwarg fields>) -> [modifies diag]
1010
"""
@@ -48,13 +48,23 @@ Base.iszero(L::BatchedDiagonalOperator) = iszero(L.diag)
4848
Base.transpose(L::BatchedDiagonalOperator) = L
4949
Base.adjoint(L::BatchedDiagonalOperator) = conj(L)
5050
function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
51-
diag = conj(L.diag)
52-
update_func = if isreal(L)
53-
L.update_func
51+
52+
update_func, update_func! = if isreal(L)
53+
L.update_func, L.update_func!
5454
else
55-
(L,u,p,t; kwargs...) -> conj(L.update_func(conj(L.diag),u,p,t; kwargs...))
55+
uf = (L, u, p, t; kwargs...) -> conj(L.update_func(conj(L.diag), u, p, t; kwargs...))
56+
uf! = (L, u, p, t; kwargs...) -> begin
57+
L.update_func(conj!(L.diag), u, p, t; kwargs...)
58+
conj!(L.diag)
59+
end
60+
uf, uf!
5661
end
57-
BatchedDiagonalOperator(diag; update_func=update_func)
62+
63+
DiagonalOperator(conj(L.diag);
64+
update_func = update_func,
65+
update_func! = update_func!,
66+
accepted_kwargs = NoKwargFilter(),
67+
)
5868
end
5969

6070
LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true

src/matrix.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ function Base.conj(L::MatrixOperator)
133133
isconstant(L) && return MatrixOperator(conj(L.A))
134134

135135
update_func = (A, u, p, t; kwargs...) -> conj(L.update_func(conj(L.A), u, p, t; kwargs...))
136-
update_func! = (A, u, p, t; kwargs...) -> conj(L.update_func!(conj(L.A), u, p, t; kwargs...))
136+
update_func! = (A, u, p, t; kwargs...) -> begin
137+
L.update_func!(conj!(L.A), u, p, t; kwargs...)
138+
conj!(L.A)
139+
end
137140

138141
MatrixOperator(conj(L.A);
139142
update_func = update_func,

0 commit comments

Comments
 (0)