Skip to content

Commit d865705

Browse files
committed
fix isinplace inference and add inference tests
1 parent 80e6ea5 commit d865705

File tree

5 files changed

+291
-489
lines changed

5 files changed

+291
-489
lines changed

src/problems/optimization_problems.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ function OptimizationProblem(
131131
OptimizationProblem{isinplace(f)}(f, args...; kwargs...)
132132
end
133133
function OptimizationProblem(f, args...; kwargs...)
134-
isinplace(f, 2, has_two_dispatches = false)
135-
OptimizationProblem{true}(OptimizationFunction{true}(f), args...; kwargs...)
134+
OptimizationProblem(OptimizationFunction(f), args...; kwargs...)
136135
end
137136

138137
function OptimizationFunction(

src/scimlfunctions.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4199,7 +4199,7 @@ IntervalNonlinearFunction(f::IntervalNonlinearFunction; kwargs...) = f
41994199
struct NoAD <: AbstractADType end
42004200

42014201
(f::OptimizationFunction)(args...) = f.f(args...)
4202-
OptimizationFunction(args...; kwargs...) = OptimizationFunction{true}(args...; kwargs...)
4202+
OptimizationFunction(f, args...; kwargs...) = OptimizationFunction{isinplace(f, 3)}(f, args...; kwargs...)
42034203

42044204
function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
42054205
grad = nothing, fg = nothing, hess = nothing, hv = nothing, fgh = nothing,
@@ -4251,8 +4251,8 @@ end
42514251
(f::MultiObjectiveOptimizationFunction)(args...) = f.f(args...)
42524252

42534253
# Convenience constructor
4254-
function MultiObjectiveOptimizationFunction(args...; kwargs...)
4255-
MultiObjectiveOptimizationFunction{true}(args...; kwargs...)
4254+
function MultiObjectiveOptimizationFunction(f, args...; kwargs...)
4255+
MultiObjectiveOptimizationFunction{isinplace(f, 3)}(f, args...; kwargs...)
42564256
end
42574257

42584258
# Constructor with keyword arguments

src/utils.jl

Lines changed: 12 additions & 209 deletions
Original file line numberDiff line numberDiff line change
@@ -42,157 +42,6 @@ function num_types_in_tuple(sig::UnionAll)
4242
length(Base.unwrap_unionall(sig).parameters)
4343
end
4444

45-
const NO_METHODS_ERROR_MESSAGE = """
46-
No methods were found for the model function passed to the equation solver.
47-
The function `f` needs to have dispatches, for example, for an ODEProblem
48-
`f` must define either `f(u,p,t)` or `f(du,u,p,t)`. For more information
49-
on how the model function `f` should be defined, consult the docstring for
50-
the appropriate `AbstractSciMLFunction`.
51-
"""
52-
53-
struct NoMethodsError <: Exception
54-
fname::String
55-
end
56-
57-
function Base.showerror(io::IO, e::NoMethodsError)
58-
println(io, NO_METHODS_ERROR_MESSAGE)
59-
print(io, "Offending function: ")
60-
printstyled(io, e.fname; bold = true, color = :red)
61-
end
62-
63-
const TOO_MANY_ARGUMENTS_ERROR_MESSAGE = """
64-
All methods for the model function `f` had too many arguments. For example,
65-
an ODEProblem `f` must define either `f(u,p,t)` or `f(du,u,p,t)`. This error
66-
can be thrown if you define an ODE model for example as `f(du,u,p1,p2,t)`.
67-
For more information on the required number of arguments for the function
68-
you were defining, consult the documentation for the `SciMLProblem` or
69-
`SciMLFunction` type that was being constructed.
70-
71-
A common reason for this occurrence is due to following the MATLAB or SciPy
72-
convention for parameter passing, i.e. to add each parameter as an argument.
73-
In the SciML convention, if you wish to pass multiple parameters, use a
74-
struct or other collection to hold the parameters. For example, here is the
75-
parameterized Lorenz equation:
76-
77-
```julia
78-
function lorenz(du,u,p,t)
79-
du[1] = p[1]*(u[2]-u[1])
80-
du[2] = u[1]*(p[2]-u[3]) - u[2]
81-
du[3] = u[1]*u[2] - p[3]*u[3]
82-
end
83-
u0 = [1.0;0.0;0.0]
84-
p = [10.0,28.0,8/3]
85-
tspan = (0.0,100.0)
86-
prob = ODEProblem(lorenz,u0,tspan,p)
87-
```
88-
89-
Notice that `f` is defined with a single `p`, an array which matches the definition
90-
of the `p` in the `ODEProblem`. Note that `p` can be any Julia struct.
91-
"""
92-
93-
struct TooManyArgumentsError <: Exception
94-
fname::String
95-
f::Any
96-
end
97-
98-
function Base.showerror(io::IO, e::TooManyArgumentsError)
99-
println(io, TOO_MANY_ARGUMENTS_ERROR_MESSAGE)
100-
print(io, "Offending function: ")
101-
printstyled(io, e.fname; bold = true, color = :red)
102-
println(io, "\nMethods:")
103-
println(io, methods(e.f))
104-
end
105-
106-
const TOO_FEW_ARGUMENTS_ERROR_MESSAGE_OPTIMIZATION = """
107-
All methods for the model function `f` had too few arguments. For example,
108-
an OptimizationProblem `f` must define `f(u,p)` where `u` is the optimization
109-
state and `p` are the parameters of the optimization (commonly, the hyperparameters
110-
of the simulation).
111-
112-
A common reason for this error is from defining a single-input loss function
113-
`f(u)`. While parameters are not required, a loss function which takes parameters
114-
is required, i.e. `f(u,p)`. If you have a function `f(u)`, ignored parameters
115-
can be easily added using a closure, i.e. `OptimizationProblem((u,_)->f(u),...)`.
116-
117-
For example, here is a parameterized optimization problem:
118-
119-
```julia
120-
using Optimization, OptimizationOptimJL
121-
rosenbrock(u,p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
122-
u0 = zeros(2)
123-
p = [1.0,100.0]
124-
125-
prob = OptimizationProblem(rosenbrock,u0,p)
126-
sol = solve(prob,NelderMead())
127-
```
128-
129-
and a parameter-less example:
130-
131-
```julia
132-
using Optimization, OptimizationOptimJL
133-
rosenbrock(u,p) = (1 - u[1])^2 + (u[2] - u[1]^2)^2
134-
u0 = zeros(2)
135-
136-
prob = OptimizationProblem(rosenbrock,u0)
137-
sol = solve(prob,NelderMead())
138-
```
139-
"""
140-
141-
const TOO_FEW_ARGUMENTS_ERROR_MESSAGE = """
142-
All methods for the model function `f` had too few arguments. For example,
143-
an ODEProblem `f` must define either `f(u,p,t)` or `f(du,u,p,t)`. This error
144-
can be thrown if you define an ODE model for example as `f(u,t)`. The parameters
145-
`p` are not optional in the definition of `f`! For more information on the required
146-
number of arguments for the function you were defining, consult the documentation
147-
for the `SciMLProblem` or `SciMLFunction` type that was being constructed.
148-
149-
For example, here is the no parameter Lorenz equation. The two valid versions
150-
are out of place:
151-
152-
```julia
153-
function lorenz(u,p,t)
154-
du1 = 10.0*(u[2]-u[1])
155-
du2 = u[1]*(28.0-u[3]) - u[2]
156-
du3 = u[1]*u[2] - 8/3*u[3]
157-
[du1,du2,du3]
158-
end
159-
u0 = [1.0;0.0;0.0]
160-
tspan = (0.0,100.0)
161-
prob = ODEProblem(lorenz,u0,tspan)
162-
```
163-
164-
and in-place:
165-
166-
```julia
167-
function lorenz!(du,u,p,t)
168-
du[1] = 10.0*(u[2]-u[1])
169-
du[2] = u[1]*(28.0-u[3]) - u[2]
170-
du[3] = u[1]*u[2] - 8/3*u[3]
171-
end
172-
u0 = [1.0;0.0;0.0]
173-
tspan = (0.0,100.0)
174-
prob = ODEProblem(lorenz!,u0,tspan)
175-
```
176-
"""
177-
178-
struct TooFewArgumentsError <: Exception
179-
fname::String
180-
f::Any
181-
isoptimization::Bool
182-
end
183-
184-
function Base.showerror(io::IO, e::TooFewArgumentsError)
185-
if e.isoptimization
186-
println(io, TOO_FEW_ARGUMENTS_ERROR_MESSAGE_OPTIMIZATION)
187-
else
188-
println(io, TOO_FEW_ARGUMENTS_ERROR_MESSAGE)
189-
end
190-
print(io, "Offending function: ")
191-
printstyled(io, e.fname; bold = true, color = :red)
192-
println(io, "\nMethods:")
193-
println(io, methods(e.f))
194-
end
195-
19645
const ARGUMENTS_ERROR_MESSAGE = """
19746
Methods dispatches for the model function `f` do not match the required number.
19847
For example, an ODEProblem `f` must define either `f(u,p,t)` or `f(du,u,p,t)`.
@@ -207,6 +56,12 @@ struct FunctionArgumentsError <: Exception
20756
f::Any
20857
end
20958

59+
# backward compat in case anyone is using these.
60+
# TODO: remove at next major version
61+
const TooManyArgumentsError = FunctionArgumentsError
62+
const TooFewArgumentsError = FunctionArgumentsError
63+
const NoMethodsError = FunctionArgumentsError
64+
21065
function Base.showerror(io::IO, e::FunctionArgumentsError)
21166
println(io, ARGUMENTS_ERROR_MESSAGE)
21267
print(io, "Offending function: ")
@@ -246,66 +101,14 @@ form is disabled and the 2-argument signature is ensured to be matched.
246101
function isinplace(f, inplace_param_number, fname = "f", iip_preferred = true;
247102
has_two_dispatches = true, isoptimization = false,
248103
outofplace_param_number = inplace_param_number - 1)
249-
nargs = numargs(f)
250-
iip_dispatch = any(x -> x == inplace_param_number, nargs)
251-
oop_dispatch = any(x -> x == outofplace_param_number, nargs)
252-
253-
if length(nargs) == 0
254-
throw(NoMethodsError(fname))
255-
end
256-
257-
if !iip_dispatch && !oop_dispatch && !isoptimization
258-
if all(>(inplace_param_number), nargs)
259-
throw(TooManyArgumentsError(fname, f))
260-
elseif all(<(outofplace_param_number), nargs) && has_two_dispatches
261-
# Possible extra safety?
262-
# Find if there's a `f(args...)` dispatch
263-
# If so, no error
264-
_parameters = if methods(f).ms[1].sig isa UnionAll
265-
Base.unwrap_unionall(methods(f).ms[1].sig).parameters
266-
else
267-
methods(f).ms[1].sig.parameters
268-
end
269-
270-
for i in 1:length(nargs)
271-
if nargs[i] < inplace_param_number &&
272-
any(isequal(Vararg{Any}), _parameters)
273-
# If varargs, assume iip
274-
return iip_preferred
275-
end
276-
end
277-
278-
# No varargs detected, error that there are dispatches but not the right ones
279-
280-
throw(TooFewArgumentsError(fname, f, isoptimization))
281-
else
282-
throw(FunctionArgumentsError(fname, f))
283-
end
284-
elseif oop_dispatch && !iip_dispatch && !has_two_dispatches
285-
286-
# Possible extra safety?
287-
# Find if there's a `f(args...)` dispatch
288-
# If so, no error
289-
for i in 1:length(nargs)
290-
if nargs[i] < inplace_param_number &&
291-
any(isequal(Vararg{Any}), methods(f).ms[1].sig.parameters)
292-
# If varargs, assume iip
293-
return iip_preferred
294-
end
295-
end
296-
297-
throw(TooFewArgumentsError(fname, f, isoptimization))
104+
if iip_preferred
105+
hasmethod(f, ntuple(_->Any, inplace_param_number)) && return true
106+
hasmethod(f, ntuple(_->Any, outofplace_param_number)) && return false
298107
else
299-
if iip_preferred
300-
# Equivalent to, if iip_dispatch exists, treat as iip
301-
# Otherwise, it's oop
302-
iip_dispatch
303-
else
304-
# Equivalent to, if oop_dispatch exists, treat as oop
305-
# Otherwise, it's iip
306-
!oop_dispatch
307-
end
108+
hasmethod(f, ntuple(_->Any, outofplace_param_number)) && return false
109+
hasmethod(f, ntuple(_->Any, inplace_param_number)) && return true
308110
end
111+
throw(FunctionArgumentsError(fname, f))
309112
end
310113

311114
isinplace(f::AbstractSciMLFunction{iip}) where {iip} = iip

test/aqua.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ end
2929
# for method_ambiguity in ambs
3030
# @show method_ambiguity
3131
# end
32-
@warn "Number of method ambiguities: $(length(ambs))"
32+
!isempty(ambs) &&@warn "Number of method ambiguities: $(length(ambs))"
3333
@test length(ambs) 8
3434
end
3535

0 commit comments

Comments
 (0)