Skip to content

Commit 11447bf

Browse files
authored
Merge pull request #183 from vpuri3/scalar
fixed scalaroperator convert methods.
2 parents 282f624 + 74f829b commit 11447bf

File tree

3 files changed

+42
-18
lines changed

3 files changed

+42
-18
lines changed

src/interface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ issquare(::Union{
226226
) = true
227227
issquare(A...) = @. (&)(issquare(A)...)
228228

229+
Base.length(L::AbstractSciMLOperator) = prod(size(L))
229230
Base.ndims(L::AbstractSciMLOperator) = length(size(L))
230231
Base.isreal(L::AbstractSciMLOperator{T}) where{T} = T <: Real
231232
Base.Matrix(L::AbstractSciMLOperator) = Matrix(convert(AbstractMatrix, L))

src/scalar.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ SCALINGCOMBINETYPES = (
2121
:IdentityOperator
2222
)
2323

24+
Base.length(::AbstractSciMLScalarOperator) = 1
2425
Base.size::AbstractSciMLScalarOperator) = ()
2526
Base.adjoint::AbstractSciMLScalarOperator) = conj(α)
2627
Base.transpose::AbstractSciMLScalarOperator) = α
@@ -110,7 +111,7 @@ function ScalarOperator(val::T; update_func=DEFAULT_UPDATE_FUNC) where{T}
110111
end
111112

112113
# constructors
113-
Base.convert(::Type{Number}, α::ScalarOperator) = α.val
114+
Base.convert(T::Type{<:Number}, α::ScalarOperator) = convert(T, α.val)
114115
Base.convert(::Type{ScalarOperator}, α::Number) = ScalarOperator(α)
115116

116117
ScalarOperator::AbstractSciMLScalarOperator) = α
@@ -173,8 +174,8 @@ for op in (
173174
end
174175
end
175176

176-
function Base.convert(::Type{Number}, α::AddedScalarOperator{T}) where{T}
177-
sum(op -> convert(Number, op), α.ops)
177+
function Base.convert(T::Type{<:Number}, α::AddedScalarOperator)
178+
sum(convert.(T, α.ops))
178179
end
179180

180181
Base.conj(L::AddedScalarOperator) = AddedScalarOperator(conj.(L.ops))
@@ -222,9 +223,9 @@ for op in (
222223
end
223224
end
224225

225-
function Base.convert(::Type{Number}, α::ComposedScalarOperator{T}) where{T}
226+
function Base.convert(T::Type{<:Number}, α::ComposedScalarOperator)
226227
iszero(α) && return zero(T)
227-
prod( op -> convert(Number, op), α.ops; init=one(T))
228+
prod(convert.(T, α.ops))
228229
end
229230

230231
Base.conj(L::ComposedScalarOperator) = ComposedScalarOperator(conj.(L.ops))
@@ -279,8 +280,8 @@ for op in (
279280
@eval Base.$op::AbstractSciMLScalarOperator, β::AbstractSciMLScalarOperator) = inv(α) * β
280281
end
281282

282-
function Base.convert(::Type{Number}, α::InvertedScalarOperator{T}) where{T}
283-
return inv(convert(Number, α.λ))
283+
function Base.convert(T::Type{<:Number}, α::InvertedScalarOperator)
284+
inv(convert(Number, α.λ))
284285
end
285286

286287
Base.conj(L::InvertedScalarOperator) = InvertedScalarOperator(conj(L.λ))

test/scalar.jl

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ K = 12
1818
@test issquare(α)
1919
@test islinear(α)
2020

21-
@test convert(Number, α) isa Number
21+
@test convert(Float32, α) isa Float32
2222
@test convert(ScalarOperator, a) isa ScalarOperator
2323

2424
@test size(α) == ()
@@ -37,16 +37,35 @@ K = 12
3737
@test axpy!(aa,X,Y) a*X+Z
3838

3939
# Test that ScalarOperator's remain AbstractSciMLScalarOperator's under common ops
40-
@test α + α isa SciMLOperators.AddedScalarOperator
41-
+ α) * u x * u + x * u
42-
@test α * α isa SciMLOperators.ComposedScalarOperator
43-
* α) * u x * x * u
44-
@test inv(α) isa SciMLOperators.InvertedScalarOperator
45-
inv(α) * u 1/x * u
46-
@test α * inv(α) isa SciMLOperators.ComposedScalarOperator
47-
α * inv(α) * u u
48-
@test α / α isa SciMLOperators.ComposedScalarOperator
49-
α * α * u u
40+
β = α + α
41+
@test β isa SciMLOperators.AddedScalarOperator
42+
@test β * u x * u + x * u
43+
@inferred convert(Float32, β)
44+
@test convert(Number, β) x + x
45+
46+
β = α * α
47+
@test β isa SciMLOperators.ComposedScalarOperator
48+
@test β * u x * x * u
49+
@inferred convert(Float32, β)
50+
@test convert(Number, β) x * x
51+
52+
β = inv(α)
53+
@test β isa SciMLOperators.InvertedScalarOperator
54+
@test β * u 1 / x * u
55+
@inferred convert(Float32, β)
56+
@test convert(Number, β) 1 / x
57+
58+
β = α * inv(α)
59+
@test β isa SciMLOperators.ComposedScalarOperator
60+
@test β * u u
61+
@inferred convert(Float32, β)
62+
@test convert(Number, β) true
63+
64+
β = α / α
65+
@test β isa SciMLOperators.ComposedScalarOperator
66+
@test β * u u
67+
@inferred convert(Float32, β)
68+
@test convert(Number, β) true
5069

5170
# Test combination with other operators
5271
for op in (MatrixOperator(rand(N, N)), SciMLOperators.IdentityOperator(N))
@@ -74,6 +93,9 @@ end
7493
@test !isconstant(α)
7594
@test !isconstant(β)
7695

96+
@test convert(Float32, α) isa Float32
97+
@test convert(Float32, β) isa Float32
98+
7799
@test convert(Number, α) 0.0
78100
@test convert(Number, β) 0.0
79101

0 commit comments

Comments
 (0)