Skip to content

Commit a5f1eba

Browse files
Merge pull request #205 from vpuri3/isconcrete
concretize methods
2 parents 303622a + 03d6d39 commit a5f1eba

File tree

8 files changed

+105
-11
lines changed

8 files changed

+105
-11
lines changed

src/SciMLOperators.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ export update_coefficients!,
9898

9999
issquare,
100100
islinear,
101+
concretize,
102+
isconvertible,
101103

102104
has_adjoint,
103105
has_expmv,

src/basic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,7 @@ end
761761

762762
getops(L::InvertedOperator) = (L.L,)
763763
islinear(L::InvertedOperator) = islinear(L.L)
764+
isconvertible(::InvertedOperator) = false
764765

765766
has_mul(L::InvertedOperator) = has_ldiv(L.L)
766767
has_mul!(L::InvertedOperator) = has_ldiv!(L.L)

src/batch.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
6767
)
6868
end
6969

70+
function Base.convert(::Type{AbstractMatrix}, L::BatchedDiagonalOperator)
71+
m, n = size(L)
72+
msg = """$L cannot be represented by an $m × $n AbstractMatrix"""
73+
throw(ArgumentError(msg))
74+
end
75+
7076
LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true
7177
function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
7278
if isreal(L)
@@ -91,6 +97,7 @@ function isconstant(L::BatchedDiagonalOperator)
9197
update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!)
9298
end
9399
islinear(::BatchedDiagonalOperator) = true
100+
isconvertible(::BatchedDiagonalOperator) = false
94101
has_adjoint(L::BatchedDiagonalOperator) = true
95102
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
96103
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)

src/func.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`.
111111
* `has_mul5` - `true` if the operator provides a five-argument `mul!` via the signature `op(v, u, p, t, α, β; <accepted_kwargs>)`. This trait is inferred if no value is provided.
112112
* `isconstant` - `true` if the operator is constant, and doesn't need to be updated via `update_coefficients[!]` during operator evaluation.
113113
* `islinear` - `true` if the operator is linear. Defaults to `false`.
114+
* `isconvertible` - `true` a cheap `convert(AbstractMatrix, L.op)` method is available. Defaults to `false`.
114115
* `batch` - Boolean indicating if the input/output arrays comprise of batched column-vectors stacked in a matrix. If `true`, the input/output arrays must be `AbstractVecOrMat`s, and the length of the second dimension (the batch dimension) must be the same. The batch dimension is not involved in size computation. For example, with `batch = true`, and `size(output), size(input) = (M, K), (N, K)`, the `FunctionOperator` size is set to `(M, N)`. If `batch = false`, which is the default, the `input`/`output` arrays may of any size so long as `ndims(input) == ndims(output)`, and the `size` of `FunctionOperator` is set to `(length(input), length(output))`.
115116
* `ifcache` - Allocate cache arrays in constructor. Defaults to `true`. Cache can be generated afterwards by calling `cache_operator(L, input, output)`
116117
* `cache` - Pregenerated cache arrays for in-place evaluations. Expected to be of type and shape `(similar(input), similar(output),)`. The constructor generates cache if no values are provided. Cache generation by the constructor can be disabled by setting the kwarg `ifcache = false`.
@@ -138,6 +139,7 @@ function FunctionOperator(op,
138139
has_mul5::Union{Nothing,Bool}=nothing,
139140
isconstant::Bool = false,
140141
islinear::Bool = false,
142+
isconvertible::Bool = false,
141143

142144
batch::Bool = false,
143145
ifcache::Bool = true,
@@ -248,6 +250,7 @@ function FunctionOperator(op,
248250

249251
traits = (;
250252
islinear = islinear,
253+
isconvertible = isconvertible,
251254
isconstant = isconstant,
252255

253256
opnorm = opnorm,
@@ -480,6 +483,8 @@ function Base.inv(L::FunctionOperator)
480483
)
481484
end
482485

486+
Base.convert(::Type{AbstractMatrix}, L::FunctionOperator) = convert(AbstractMatrix, L.op)
487+
483488
function Base.resize!(L::FunctionOperator, n::Integer)
484489

485490
# input/output to `L` must be `AbstractVector`s
@@ -526,6 +531,7 @@ function getops(L::FunctionOperator)
526531
end
527532

528533
islinear(L::FunctionOperator) = L.traits.islinear
534+
isconvertible(L::FunctionOperator) = L.traits.isconvertible
529535
isconstant(L::FunctionOperator) = L.traits.isconstant
530536
has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing)
531537
has_mul(::FunctionOperator{iip}) where{iip} = true

src/interface.jl

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,37 +187,53 @@ Base.oneunit(LType::Type{<:AbstractSciMLOperator}) = one(LType)
187187
Base.iszero(::AbstractSciMLOperator) = false # TODO
188188

189189
"""
190+
$SIGNATURES
191+
190192
Check if `adjoint(L)` is lazily defined.
191193
"""
192194
has_adjoint(L::AbstractSciMLOperator) = false # L', adjoint(L)
193195
"""
196+
$SIGNATURES
197+
194198
Check if `expmv!(v, L, u, t)`, equivalent to `mul!(v, exp(t * A), u)`, is
195199
defined for `Number` `t`, and `AbstractArray`s `u, v` of appropriate sizes.
196200
"""
197201
has_expmv!(L::AbstractSciMLOperator) = false # expmv!(v, L, t, u)
198202
"""
203+
$SIGNATURES
204+
199205
Check if `expmv(L, u, t)`, equivalent to `exp(t * A) * u`, is defined for
200206
`Number` `t`, and `AbstractArray` `u` of appropriate size.
201207
"""
202208
has_expmv(L::AbstractSciMLOperator) = false # v = exp(L, t, u)
203209
"""
210+
$SIGNATURES
211+
204212
Check if `exp(L)` is defined lazily defined.
205213
"""
206214
has_exp(L::AbstractSciMLOperator) = islinear(L)
207215
"""
216+
$SIGNATURES
217+
208218
Check if `L * u` is defined for `AbstractArray` `u` of appropriate size.
209219
"""
210220
has_mul(L::AbstractSciMLOperator) = true # du = L*u
211221
"""
222+
$SIGNATURES
223+
212224
Check if `mul!(v, L, u)` is defined for `AbstractArray`s `u, v` of
213225
appropriate sizes.
214226
"""
215227
has_mul!(L::AbstractSciMLOperator) = true # mul!(du, L, u)
216228
"""
229+
$SIGNATURES
230+
217231
Check if `L \\ u` is defined for `AbstractArray` `u` of appropriate size.
218232
"""
219233
has_ldiv(L::AbstractSciMLOperator) = false # du = L\u
220234
"""
235+
$SIGNATURES
236+
221237
Check if `ldiv!(v, L, u)` is defined for `AbstractArray`s `u, v` of
222238
appropriate sizes.
223239
"""
@@ -244,7 +260,57 @@ isconstant(::Union{
244260
) = true
245261
isconstant(L::AbstractSciMLOperator) = all(isconstant, getops(L))
246262

247-
#islinear(L) = false
263+
"""
264+
isconvertible(L) -> Bool
265+
266+
Checks if `L` can be cheaply converted to an `AbstractMatrix` via eager fusion.
267+
"""
268+
isconvertible(L::AbstractSciMLOperator) = all(isconvertible, getops(L))
269+
270+
isconvertible(::Union{
271+
# LinearAlgebra
272+
AbstractMatrix,
273+
UniformScaling,
274+
Factorization,
275+
276+
# Base
277+
Number,
278+
279+
# SciMLOperators
280+
AbstractSciMLScalarOperator,
281+
}
282+
) = true
283+
284+
"""
285+
concretize(L) -> AbstractMatrix
286+
287+
concretize(L) -> Number
288+
289+
Convert `SciMLOperator` to a concrete type via eager fusion. This method is a
290+
no-op for types that are already concrete.
291+
"""
292+
concretize(L::Union{
293+
# LinearAlgebra
294+
AbstractMatrix,
295+
Factorization,
296+
297+
# SciMLOperators
298+
AbstractSciMLOperator,
299+
}
300+
) = convert(AbstractMatrix, L)
301+
302+
concretize(L::Union{
303+
# LinearAlgebra
304+
UniformScaling,
305+
306+
# Base
307+
Number,
308+
309+
# SciMLOperators
310+
AbstractSciMLScalarOperator,
311+
}
312+
) = convert(Number, L)
313+
248314
"""
249315
$SIGNATURES
250316
@@ -349,22 +415,22 @@ expmv!(v,L::AbstractSciMLOperator,u,p,t) = mul!(v,exp(L,t),u)
349415
function Base.conj(L::AbstractSciMLOperator)
350416
isreal(L) && return L
351417
@warn """using convert-based fallback for Base.conj"""
352-
convert(AbstractMatrix, L) |> conj
418+
concretize(L) |> conj
353419
end
354420

355421
function Base.:(==)(L1::AbstractSciMLOperator, L2::AbstractSciMLOperator)
356422
@warn """using convert-based fallback for Base.=="""
357423
size(L1) != size(L2) && return false
358-
convert(AbstractMatrix, L1) == convert(AbstractMatrix, L1)
424+
concretize(L1) == concretize(L2)
359425
end
360426

361427
Base.@propagate_inbounds function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Any,N}) where {N}
362428
@warn """using convert-based fallback for Base.getindex"""
363-
convert(AbstractMatrix, L)[I...]
429+
concretize(L)[I...]
364430
end
365431
function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Int, N}) where {N}
366432
@warn """using convert-based fallback for Base.getindex"""
367-
convert(AbstractMatrix, L)[I...]
433+
concretize(L)[I...]
368434
end
369435

370436
function Base.resize!(L::AbstractSciMLOperator, n::Integer)
@@ -375,15 +441,15 @@ LinearAlgebra.exp(L::AbstractSciMLOperator) = exp(Matrix(L))
375441

376442
function LinearAlgebra.opnorm(L::AbstractSciMLOperator, p::Real=2)
377443
@warn """using convert-based fallback in LinearAlgebra.opnorm."""
378-
opnorm(convert(AbstractMatrix, L), p)
444+
opnorm(concretize(L), p)
379445
end
380446

381447
for op in (
382448
:sum, :prod,
383449
)
384450
@eval function Base.$op(L::AbstractSciMLOperator; kwargs...)
385451
@warn """using convert-based fallback in $($op)."""
386-
$op(convert(AbstractMatrix, L); kwargs...)
452+
$op(concretize(L); kwargs...)
387453
end
388454
end
389455

@@ -394,17 +460,17 @@ for pred in (
394460
)
395461
@eval function LinearAlgebra.$pred(L::AbstractSciMLOperator)
396462
@warn """using convert-based fallback in $($pred)."""
397-
$pred(convert(AbstractMatrix, L))
463+
$pred(concretize(L))
398464
end
399465
end
400466

401467
function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray)
402468
@warn """using convert-based fallback in mul!."""
403-
mul!(v, convert(AbstractMatrix, L), u)
469+
mul!(v, concretize(L), u)
404470
end
405471

406472
function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray, α, β)
407473
@warn """using convert-based fallback in mul!."""
408-
mul!(v, convert(AbstractMatrix, L), u, α, β)
474+
mul!(v, concretize(L), u, α, β)
409475
end
410476
#

src/matrix.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ end
103103
has_ldiv,
104104
has_ldiv!,
105105
)
106+
107+
isconvertible(::MatrixOperator) = true
106108
islinear(::MatrixOperator) = true
107109

108110
function Base.show(io::IO, L::MatrixOperator)
@@ -162,7 +164,7 @@ SparseArrays.issparse(L::MatrixOperator) = issparse(L.A)
162164

163165
# TODO - add tests for MatrixOperator indexing
164166
# propagate_inbounds here for the getindex fallback
165-
Base.@propagate_inbounds Base.convert(::Type{AbstractMatrix}, L::MatrixOperator) = L.A
167+
Base.@propagate_inbounds Base.convert(::Type{AbstractMatrix}, L::MatrixOperator) = convert(AbstractMatrix, L.A)
166168
Base.@propagate_inbounds Base.setindex!(L::MatrixOperator, v, i::Int) = (L.A[i] = v)
167169
Base.@propagate_inbounds Base.setindex!(L::MatrixOperator, v, I::Vararg{Int, N}) where{N} = (L.A[I...] = v)
168170

@@ -322,6 +324,7 @@ end
322324

323325
getops(L::InvertibleOperator) = (L.L, L.F,)
324326
islinear(L::InvertibleOperator) = islinear(L.L)
327+
isconvertible(L::InvertibleOperator) = isconvertible(L.L)
325328

326329
@forward InvertibleOperator.L (
327330
# LinearAlgebra
@@ -510,6 +513,7 @@ end
510513
getops(L::AffineOperator) = (L.A, L.B, L.b)
511514

512515
islinear(::AffineOperator) = false
516+
isconvertible(::AffineOperator) = false
513517

514518
function Base.show(io::IO, L::AffineOperator)
515519
show(io, L.A)
@@ -537,6 +541,12 @@ function Base.resize!(L::AffineOperator, n::Integer)
537541
L
538542
end
539543

544+
function Base.convert(::Type{AbstractMatrix}, L::AffineOperator)
545+
m, n = size(L)
546+
msg = """$L cannot be represented by an $m × $n AbstractMatrix"""
547+
throw(ArgumentError(msg))
548+
end
549+
540550
has_adjoint(L::AffineOperator) = false
541551
has_mul(L::AffineOperator) = has_mul(L.A)
542552
has_mul!(L::AffineOperator) = has_mul!(L.A)

src/scalar.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Base.adjoint(α::AbstractSciMLScalarOperator) = conj(α)
3232
Base.transpose::AbstractSciMLScalarOperator) = α
3333

3434
has_mul!(::AbstractSciMLScalarOperator) = true
35+
isconcrete(::AbstractSciMLScalarOperator) = true
3536
islinear(::AbstractSciMLScalarOperator) = true
3637
has_adjoint(::AbstractSciMLScalarOperator) = true
3738

src/tensor.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ end
121121

122122
getops(L::TensorProductOperator) = L.ops
123123
islinear(L::TensorProductOperator) = reduce(&, islinear.(L.ops))
124+
isconvertible(::TensorProductOperator) = false
124125
Base.iszero(L::TensorProductOperator) = reduce(|, iszero.(L.ops))
125126
has_adjoint(L::TensorProductOperator) = reduce(&, has_adjoint.(L.ops))
126127
has_mul(L::TensorProductOperator) = reduce(&, has_mul.(L.ops))

0 commit comments

Comments
 (0)