Skip to content

Commit d728a14

Browse files
Merge pull request #136 from vpuri3/ambig
Fix method ambiguities, add tests
2 parents f0eb5c0 + 58cb04c commit d728a14

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

src/scalar.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ for op in (
209209

210210
for T in SCALINGNUMBERTYPES[2:end]
211211
@eval Base.$op::AbstractSciMLScalarOperator, x::$T) = ComposedScalarOperator(α, ScalarOperator(x))
212-
@eval Base.$op(x::$T, α::AbstractSciMLScalarOperator) = ComposedScalarOperator(ScalarOperator(x), x)
212+
@eval Base.$op(x::$T, α::AbstractSciMLScalarOperator) = ComposedScalarOperator(ScalarOperator(x), α)
213213
end
214214
end
215215

@@ -247,8 +247,9 @@ for op in (
247247
for T in SCALINGNUMBERTYPES[2:end]
248248
@eval Base.$op::AbstractSciMLScalarOperator, x::$T) = α * inv(ScalarOperator(x))
249249
@eval Base.$op(x::$T, α::AbstractSciMLScalarOperator) = ScalarOperator(x) * inv(α)
250-
@eval Base.$op::AbstractSciMLScalarOperator, β::AbstractSciMLScalarOperator) = α * inv(β)
251250
end
251+
252+
@eval Base.$op::AbstractSciMLScalarOperator, β::AbstractSciMLScalarOperator) = α * inv(β)
252253
end
253254

254255
for op in (
@@ -257,8 +258,9 @@ for op in (
257258
for T in SCALINGNUMBERTYPES[2:end]
258259
@eval Base.$op::AbstractSciMLScalarOperator, x::$T) = inv(α) * ScalarOperator(x)
259260
@eval Base.$op(x::$T, α::AbstractSciMLScalarOperator) = inv(ScalarOperator(x)) * α
260-
@eval Base.$op::AbstractSciMLScalarOperator, β::AbstractSciMLScalarOperator) = inv(α) * β
261261
end
262+
263+
@eval Base.$op::AbstractSciMLScalarOperator, β::AbstractSciMLScalarOperator) = inv(α) * β
262264
end
263265

264266
function Base.convert(::Type{Number}, α::InvertedScalarOperator{T}) where{T}

src/tensor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ function cache_self(L::TensorProductOperator, u::AbstractVecOrMat)
147147
L
148148
end
149149

150-
function cache_internals(L::TensorProductOperator, u::AbstractVecOrMat) where{D}
150+
function cache_internals(L::TensorProductOperator, u::AbstractVecOrMat)
151151
if !(L.isset)
152152
L = cache_self(L, u)
153153
end

test/scalar.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,23 @@ K = 12
5656
end
5757

5858
@testset "ScalarOperator update test" begin
59-
u = rand(N,K)
60-
p = rand(N)
61-
t = 0.0
59+
u = ones(N,K)
60+
v = zeros(N,K)
61+
p = rand()
62+
t = rand()
63+
64+
α = ScalarOperator(0.0; update_func=(a,u,p,t) -> p)
65+
β = ScalarOperator(0.0; update_func=(a,u,p,t) -> t)
66+
67+
@test α(u,p,t) p * u
68+
@test α(v,u,p,t) p * u
69+
70+
num = α + 2 / β * 3 - 4
71+
val = p + 2 / t * 3 - 4
6272

63-
α = ScalarOperator(zero(Float64);
64-
update_func=(a,u,p,t) -> sum(p)
65-
)
73+
@test num(u,p,t) val * u
74+
@test num(v,u,p,t) val * u
6675

67-
ans = sum(p) * u
68-
@test α(u,p,t) ans
69-
v=copy(u); @test α(v,u,p,t) ans
76+
@test convert(Number, num) val
7077
end
7178
#

0 commit comments

Comments
 (0)