@@ -8,26 +8,32 @@ by `update_coefficients!` and is assumed to have the following signature:
88
99 update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
1010"""
11- struct BatchedDiagonalOperator{T,D,F} <: AbstractSciMLOperator{T}
11+ struct BatchedDiagonalOperator{T,D,F,F! } <: AbstractSciMLOperator{T}
1212 diag:: D
1313 update_func:: F
14+ update_func!:: F!
1415
15- function BatchedDiagonalOperator (
16- diag:: AbstractArray ;
17- update_func= DEFAULT_UPDATE_FUNC
18- )
16+ function BatchedDiagonalOperator (diag:: AbstractArray , update_func, update_func!)
1917 new{
2018 eltype (diag),
2119 typeof (diag),
22- typeof (update_func)
20+ typeof (update_func),
21+ typeof (update_func!),
2322 }(
24- diag, update_func,
23+ diag, update_func, update_func!,
2524 )
2625 end
2726end
2827
29- function DiagonalOperator (u:: AbstractArray ; update_func= DEFAULT_UPDATE_FUNC)
30- BatchedDiagonalOperator (u; update_func= update_func)
28+ function BatchedDiagonalOperator (diag:: AbstractArray ;
29+ update_func = DEFAULT_UPDATE_FUNC,
30+ update_func! = DEFAULT_UPDATE_FUNC)
31+ BatchedDiagonalOperator (diag, update_func, update_func!)
32+ end
33+
34+ function DiagonalOperator (u:: AbstractArray ; update_func = DEFAULT_UPDATE_FUNC,
35+ update_func! = DEFAULT_UPDATE_FUNC)
36+ BatchedDiagonalOperator (u; update_func = update_func, update_func! = update_func!)
3137end
3238
3339# traits
@@ -45,6 +51,21 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
4551 BatchedDiagonalOperator (diag; update_func= update_func)
4652end
4753
54+ function update_coefficients (L:: BatchedDiagonalOperator ,u,p,t)
55+ @set! L. diag = L. update_func (L. diag,u,p,t)
56+ end
57+ update_coefficients! (L:: BatchedDiagonalOperator ,u,p,t) = (L. update_func! (L. diag,u,p,t); L)
58+
59+ getops (L:: BatchedDiagonalOperator ) = (L. diag,)
60+
61+ function isconstant (L:: BatchedDiagonalOperator )
62+ L. update_func == L. update_func! == DEFAULT_UPDATE_FUNC
63+ end
64+ islinear (:: BatchedDiagonalOperator ) = true
65+ has_adjoint (L:: BatchedDiagonalOperator ) = true
66+ has_ldiv (L:: BatchedDiagonalOperator ) = all (x -> ! iszero (x), L. diag)
67+ has_ldiv! (L:: BatchedDiagonalOperator ) = has_ldiv (L)
68+
4869LinearAlgebra. issymmetric (L:: BatchedDiagonalOperator ) = true
4970function LinearAlgebra. ishermitian (L:: BatchedDiagonalOperator )
5071 if isreal (L)
@@ -57,16 +78,6 @@ function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
5778end
5879LinearAlgebra. isposdef (L:: BatchedDiagonalOperator ) = isposdef (Diagonal (vec (L. diag)))
5980
60- isconstant (L:: BatchedDiagonalOperator ) = L. update_func == DEFAULT_UPDATE_FUNC
61- islinear (:: BatchedDiagonalOperator ) = true
62- has_adjoint (L:: BatchedDiagonalOperator ) = true
63- has_ldiv (L:: BatchedDiagonalOperator ) = all (x -> ! iszero (x), L. diag)
64- has_ldiv! (L:: BatchedDiagonalOperator ) = has_ldiv (L)
65-
66- getops (L:: BatchedDiagonalOperator ) = (L. diag,)
67-
68- update_coefficients! (L:: BatchedDiagonalOperator ,u,p,t) = (L. update_func (L. diag,u,p,t); nothing )
69-
7081# operator application
7182Base.:* (L:: BatchedDiagonalOperator , u:: AbstractVecOrMat ) = L. diag .* u
7283Base.:\ (L:: BatchedDiagonalOperator , u:: AbstractVecOrMat ) = L. diag .\ u
0 commit comments