1
+ """
2
+ @add_kwonly function_definition
3
+
4
+ Define keyword-only version of the `function_definition`.
5
+
6
+ @add_kwonly function f(x; y=1)
7
+ ...
8
+ end
9
+
10
+ expands to:
11
+
12
+ function f(x; y=1)
13
+ ...
14
+ end
15
+ function f(; x = error("No argument x"), y=1)
16
+ ...
17
+ end
18
+ """
19
+ macro add_kwonly (ex)
20
+ esc (add_kwonly (ex))
21
+ end
22
+
23
+ add_kwonly (ex:: Expr ) = add_kwonly (Val{ex. head}, ex)
24
+
25
+ function add_kwonly (:: Type{<: Val} , ex)
26
+ error (" add_only does not work with expression $(ex. head) " )
27
+ end
28
+
29
+ function add_kwonly (:: Union {Type{Val{:function }},
30
+ Type{Val{:(= )}}}, ex:: Expr )
31
+ body = ex. args[2 : end ] # function body
32
+ default_call = ex. args[1 ] # e.g., :(f(a, b=2; c=3))
33
+ kwonly_call = add_kwonly (default_call)
34
+ if kwonly_call === nothing
35
+ return ex
36
+ end
37
+
38
+ return quote
39
+ begin
40
+ $ ex
41
+ $ (Expr (ex. head, kwonly_call, body... ))
42
+ end
43
+ end
44
+ end
45
+
46
+ function add_kwonly (:: Type{Val{:where}} , ex:: Expr )
47
+ default_call = ex. args[1 ]
48
+ rest = ex. args[2 : end ]
49
+ kwonly_call = add_kwonly (default_call)
50
+ if kwonly_call === nothing
51
+ return nothing
52
+ end
53
+ return Expr (:where , kwonly_call, rest... )
54
+ end
55
+
56
+ function add_kwonly (:: Type{Val{:call}} , default_call:: Expr )
57
+ # default_call is, e.g., :(f(a, b=2; c=3))
58
+ funcname = default_call. args[1 ] # e.g., :f
59
+ required = [] # required positional arguments; e.g., [:a]
60
+ optional = [] # optional positional arguments; e.g., [:(b=2)]
61
+ default_kwargs = []
62
+ for arg in default_call. args[2 : end ]
63
+ if isa (arg, Symbol)
64
+ push! (required, arg)
65
+ elseif arg. head == :(:: )
66
+ push! (required, arg)
67
+ elseif arg. head == :kw
68
+ push! (optional, arg)
69
+ elseif arg. head == :parameters
70
+ @assert default_kwargs == [] # can I have :parameters twice?
71
+ default_kwargs = arg. args
72
+ else
73
+ error (" Not expecting to see: $arg " )
74
+ end
75
+ end
76
+ if isempty (required) && isempty (optional)
77
+ # If the function is already keyword-only, do nothing:
78
+ return nothing
79
+ end
80
+ if isempty (required)
81
+ # It's not clear what should be done. Let's not support it at
82
+ # the moment:
83
+ error (" At least one positional mandatory argument is required." )
84
+ end
85
+
86
+ kwonly_kwargs = Expr (:parameters , [
87
+ Expr (:kw , pa, :(error ($ (" No argument $pa " ))))
88
+ for pa in required
89
+ ]. .. , optional... , default_kwargs... )
90
+ kwonly_call = Expr (:call , funcname, kwonly_kwargs)
91
+ # e.g., :(f(; a=error(...), b=error(...), c=1, d=2))
92
+
93
+ return kwonly_call
94
+ end
95
+
96
+ function num_types_in_tuple (sig)
97
+ length (sig. parameters)
98
+ end
99
+
100
+ function num_types_in_tuple (sig:: UnionAll )
101
+ length (Base. unwrap_unionall (sig). parameters)
102
+ end
103
+
104
+ function numargs (f)
105
+ typ = Tuple{Any, Val{:analytic }, Vararg}
106
+ typ2 = Tuple{Any, Type{Val{:analytic }}, Vararg} # This one is required for overloaded types
107
+ typ3 = Tuple{Any, Val{:jac }, Vararg}
108
+ typ4 = Tuple{Any, Type{Val{:jac }}, Vararg} # This one is required for overloaded types
109
+ typ5 = Tuple{Any, Val{:tgrad }, Vararg}
110
+ typ6 = Tuple{Any, Type{Val{:tgrad }}, Vararg} # This one is required for overloaded types
111
+ numparam = maximum ([(m. sig<: typ || m. sig<: typ2 || m. sig<: typ3 || m. sig<: typ4 || m. sig<: typ5 || m. sig<: typ6 ) ? 0 : num_types_in_tuple (m. sig) for m in methods (f)])
112
+ return (numparam- 1 ) # -1 in v0.5 since it adds f as the first parameter
113
+ end
114
+
115
+ function isinplace (f,inplace_param_number)
116
+ numargs (f)>= inplace_param_number
117
+ end
118
+
119
+ # ## Default Linsolve
120
+
121
+ # Try to be as smart as possible
122
+ # lu! if Matrix
123
+ # lu if sparse
124
+ # gmres if operator
125
+
126
+ mutable struct DefaultLinSolve
127
+ A
128
+ iterable
129
+ end
130
+ DefaultLinSolve () = DefaultLinSolve (nothing , nothing )
131
+
132
+ function (p:: DefaultLinSolve )(x,A,b,update_matrix= false ;tol= nothing , kwargs... )
133
+ if p. iterable isa Vector && eltype (p. iterable) <: LinearAlgebra.BlasInt # `iterable` here is the pivoting vector
134
+ F = LU {eltype(A)} (A, p. iterable, zero (LinearAlgebra. BlasInt))
135
+ ldiv! (x, F, b)
136
+ return nothing
137
+ end
138
+ if update_matrix
139
+ if typeof (A) <: Matrix
140
+ blasvendor = BLAS. vendor ()
141
+ # if the user doesn't use OpenBLAS, we assume that is a better BLAS
142
+ # implementation like MKL
143
+ #
144
+ # RecursiveFactorization seems to be consistantly winning below 100
145
+ # https://discourse.julialang.org/t/ann-recursivefactorization-jl/39213
146
+ if ArrayInterface. can_setindex (x) && (size (A,1 ) <= 100 || ((blasvendor === :openblas || blasvendor === :openblas64 ) && size (A,1 ) <= 500 ))
147
+ p. A = RecursiveFactorization. lu! (A)
148
+ else
149
+ p. A = lu! (A)
150
+ end
151
+ elseif typeof (A) <: Tridiagonal
152
+ p. A = lu! (A)
153
+ elseif typeof (A) <: Union{SymTridiagonal}
154
+ p. A = ldlt! (A)
155
+ elseif typeof (A) <: Union{Symmetric,Hermitian}
156
+ p. A = bunchkaufman! (A)
157
+ elseif typeof (A) <: SparseMatrixCSC
158
+ p. A = lu (A)
159
+ elseif ArrayInterface. isstructured (A)
160
+ p. A = factorize (A)
161
+ elseif ! (typeof (A) <: AbstractDiffEqOperator )
162
+ # Most likely QR is the one that is overloaded
163
+ # Works on things like CuArrays
164
+ p. A = qr (A)
165
+ end
166
+ end
167
+
168
+ if typeof (A) <: Union{Matrix,SymTridiagonal,Tridiagonal,Symmetric,Hermitian} # No 2-arg form for SparseArrays!
169
+ x .= b
170
+ ldiv! (p. A,x)
171
+ # Missing a little bit of efficiency in a rare case
172
+ # elseif typeof(A) <: DiffEqArrayOperator
173
+ # ldiv!(x,p.A,b)
174
+ elseif ArrayInterface. isstructured (A) || A isa SparseMatrixCSC
175
+ ldiv! (x,p. A,b)
176
+ elseif typeof (A) <: AbstractDiffEqOperator
177
+ # No good starting guess, so guess zero
178
+ if p. iterable === nothing
179
+ p. iterable = IterativeSolvers. gmres_iterable! (x,A,b;initially_zero= true ,restart= 5 ,maxiter= 5 ,tol= 1e-16 ,kwargs... )
180
+ p. iterable. reltol = tol
181
+ end
182
+ x .= false
183
+ iter = p. iterable
184
+ purge_history! (iter, x, b)
185
+
186
+ for residual in iter
187
+ end
188
+ else
189
+ ldiv! (x,p. A,b)
190
+ end
191
+ return nothing
192
+ end
193
+
194
+ function (p:: DefaultLinSolve )(:: Type{Val{:init}} ,f,u0_prototype)
195
+ DefaultLinSolve ()
196
+ end
197
+
198
+ const DEFAULT_LINSOLVE = DefaultLinSolve ()
199
+
200
+ @inline UNITLESS_ABS2 (x) = real (abs2 (x))
201
+ @inline DEFAULT_NORM (u:: Union{AbstractFloat,Complex} ) = @fastmath abs (u)
202
+ @inline DEFAULT_NORM (u:: Array{T} ) where T<: Union{AbstractFloat,Complex} =
203
+ sqrt (real (sum (abs2,u)) / length (u))
204
+ @inline DEFAULT_NORM (u:: StaticArray{T} ) where T<: Union{AbstractFloat,Complex} =
205
+ sqrt (real (sum (abs2,u)) / length (u))
206
+ @inline DEFAULT_NORM (u:: RecursiveArrayTools.AbstractVectorOfArray ) =
207
+ sum (sqrt (real (sum (UNITLESS_ABS2,_u)) / length (_u)) for _u in u. u)
208
+ @inline DEFAULT_NORM (u:: AbstractArray ) = sqrt (real (sum (UNITLESS_ABS2,u)) / length (u))
209
+ @inline DEFAULT_NORM (u) = norm (u)
210
+
1
211
"""
2
212
prevfloat_tdir(x, x0, x1)
3
213
@@ -24,6 +234,3 @@ function value_derivative(f::F, x::R) where {F,R}
24
234
out = f (ForwardDiff. Dual {T} (x, one (x)))
25
235
ForwardDiff. value (out), ForwardDiff. extract_derivative (T, out)
26
236
end
27
-
28
- DiffEqBase. has_Wfact (f:: Function ) = false
29
- DiffEqBase. has_Wfact_t (f:: Function ) = false
0 commit comments