Skip to content

Commit a01d733

Browse files
committed
What is the problem ????
1 parent c745094 commit a01d733

File tree

4 files changed

+102
-50
lines changed

4 files changed

+102
-50
lines changed

src/dyno/dynsym/autodiff.py

Lines changed: 71 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
11
import math
2+
import numpy as np
3+
4+
5+
def _get_math_module(x):
6+
"""Return the appropriate math module (math or numpy) based on input type."""
7+
if isinstance(x, np.ndarray):
8+
return np
9+
else:
10+
return math
211

312

413
class DNumber:
@@ -93,11 +102,12 @@ def __pow__(self, power):
93102
if isinstance(power, DNumber):
94103
new_value = self.value**power.value
95104
new_derivatives = {}
105+
m = _get_math_module(self.value)
96106
for var in set(self.derivatives.keys()).union(power.derivatives.keys()):
97107
deriv1 = self.derivatives.get(var, 0)
98108
deriv2 = power.derivatives.get(var, 0)
99109
new_derivatives[var] = new_value * (
100-
deriv1 * power.value / self.value + deriv2 * math.log(self.value)
110+
deriv1 * power.value / self.value + deriv2 * m.log(self.value)
101111
)
102112
return DNumber(new_value, new_derivatives)
103113
else:
@@ -113,8 +123,9 @@ def __rpow__(self, base):
113123
return base.__pow__(self)
114124
else:
115125
new_value = base**self.value
126+
m = _get_math_module(self.value)
116127
new_derivatives = {
117-
var: deriv * new_value * math.log(base)
128+
var: deriv * new_value * m.log(base)
118129
for var, deriv in self.derivatives.items()
119130
}
120131
return DNumber(new_value, new_derivatives)
@@ -133,72 +144,79 @@ def __repr__(self):
133144
def sin(x):
134145
"""Sine function that works with both floats and DNumber objects."""
135146
if isinstance(x, DNumber):
136-
new_value = math.sin(x.value)
147+
new_value = sin(x.value)
137148
new_derivatives = {
138-
var: deriv * math.cos(x.value) for var, deriv in x.derivatives.items()
149+
var: deriv * cos(x.value) for var, deriv in x.derivatives.items()
139150
}
140151
return DNumber(new_value, new_derivatives)
141152
else:
142-
return math.sin(x)
153+
m = _get_math_module(x)
154+
return m.sin(x)
143155

144156

145157
def cos(x):
146158
"""Cosine function that works with both floats and DNumber objects."""
147159
if isinstance(x, DNumber):
148-
new_value = math.cos(x.value)
160+
new_value = cos(x.value)
149161
new_derivatives = {
150-
var: -deriv * math.sin(x.value) for var, deriv in x.derivatives.items()
162+
var: -deriv * sin(x.value) for var, deriv in x.derivatives.items()
151163
}
152164
return DNumber(new_value, new_derivatives)
153165
else:
154-
return math.cos(x)
166+
m = _get_math_module(x)
167+
return m.cos(x)
155168

156169

157170
def tan(x):
158171
"""Tangent function that works with both floats and DNumber objects."""
159172
if isinstance(x, DNumber):
160-
new_value = math.tan(x.value)
161-
sec_squared = 1 / (math.cos(x.value) ** 2)
173+
new_value = tan(x.value)
174+
sec_squared = 1 / (cos(x.value) ** 2)
162175
new_derivatives = {
163176
var: deriv * sec_squared for var, deriv in x.derivatives.items()
164177
}
165178
return DNumber(new_value, new_derivatives)
166179
else:
167-
return math.tan(x)
180+
m = _get_math_module(x)
181+
return m.tan(x)
168182

169183

170184
def exp(x):
171185
"""Exponential function that works with both floats and DNumber objects."""
186+
from rich import print
172187
if isinstance(x, DNumber):
173-
new_value = math.exp(x.value)
188+
new_value = exp(x.value)
174189
new_derivatives = {
175190
var: deriv * new_value for var, deriv in x.derivatives.items()
176191
}
177192
return DNumber(new_value, new_derivatives)
178193
else:
179-
return math.exp(x)
194+
m = _get_math_module(x)
195+
return m.exp(x)
180196

181197

182198
def log(x):
183199
"""Natural logarithm function that works with both floats and DNumber objects."""
184200
if isinstance(x, DNumber):
185-
new_value = math.log(x.value)
201+
new_value = log(x.value)
186202
new_derivatives = {var: deriv / x.value for var, deriv in x.derivatives.items()}
187203
return DNumber(new_value, new_derivatives)
188204
else:
189-
return math.log(x)
205+
m = _get_math_module(x)
206+
return m.log(x)
190207

191208

192209
def sqrt(x):
193210
"""Square root function that works with both floats and DNumber objects."""
194211
if isinstance(x, DNumber):
195-
new_value = math.sqrt(x.value)
212+
new_value = sqrt(x.value)
196213
new_derivatives = {
197214
var: deriv / (2 * new_value) for var, deriv in x.derivatives.items()
198215
}
199216
return DNumber(new_value, new_derivatives)
200217
else:
201-
return math.sqrt(x)
218+
m = _get_math_module(x)
219+
return m.sqrt(x)
202220

203221

204222
def dabs(x):
@@ -215,77 +233,83 @@ def dabs(x):
215233
def sinh(x):
216234
"""Hyperbolic sine function that works with both floats and DNumber objects."""
217235
if isinstance(x, DNumber):
218-
new_value = math.sinh(x.value)
236+
new_value = sinh(x.value)
219237
new_derivatives = {
220-
var: deriv * math.cosh(x.value) for var, deriv in x.derivatives.items()
238+
var: deriv * cosh(x.value) for var, deriv in x.derivatives.items()
221239
}
222240
return DNumber(new_value, new_derivatives)
223241
else:
224-
return math.sinh(x)
242+
m = _get_math_module(x)
243+
return m.sinh(x)
225244

226245

227246
def cosh(x):
228247
"""Hyperbolic cosine function that works with both floats and DNumber objects."""
229248
if isinstance(x, DNumber):
230-
new_value = math.cosh(x.value)
249+
new_value = cosh(x.value)
231250
new_derivatives = {
232-
var: deriv * math.sinh(x.value) for var, deriv in x.derivatives.items()
251+
var: deriv * sinh(x.value) for var, deriv in x.derivatives.items()
233252
}
234253
return DNumber(new_value, new_derivatives)
235254
else:
236-
return math.cosh(x)
255+
m = _get_math_module(x)
256+
return m.cosh(x)
237257

238258

239259
def tanh(x):
240260
"""Hyperbolic tangent function that works with both floats and DNumber objects."""
241261
if isinstance(x, DNumber):
242-
new_value = math.tanh(x.value)
262+
new_value = tanh(x.value)
243263
sech_squared = 1 - new_value**2
244264
new_derivatives = {
245265
var: deriv * sech_squared for var, deriv in x.derivatives.items()
246266
}
247267
return DNumber(new_value, new_derivatives)
248268
else:
249-
return math.tanh(x)
269+
m = _get_math_module(x)
270+
return m.tanh(x)
250271

251272

252273
def asin(x):
253274
"""Arcsine function that works with both floats and DNumber objects."""
254275
if isinstance(x, DNumber):
255-
new_value = math.asin(x.value)
256-
derivative_factor = 1 / math.sqrt(1 - x.value**2)
276+
new_value = asin(x.value)
277+
derivative_factor = 1 / sqrt(1 - x.value**2)
257278
new_derivatives = {
258279
var: deriv * derivative_factor for var, deriv in x.derivatives.items()
259280
}
260281
return DNumber(new_value, new_derivatives)
261282
else:
262-
return math.asin(x)
283+
m = _get_math_module(x)
284+
return np.arcsin(x) if m is np else math.asin(x)
263285

264286

265287
def acos(x):
266288
"""Arccosine function that works with both floats and DNumber objects."""
267289
if isinstance(x, DNumber):
268-
new_value = math.acos(x.value)
269-
derivative_factor = -1 / math.sqrt(1 - x.value**2)
290+
new_value = acos(x.value)
291+
derivative_factor = -1 / sqrt(1 - x.value**2)
270292
new_derivatives = {
271293
var: deriv * derivative_factor for var, deriv in x.derivatives.items()
272294
}
273295
return DNumber(new_value, new_derivatives)
274296
else:
275-
return math.acos(x)
297+
m = _get_math_module(x)
298+
return np.arccos(x) if m is np else math.acos(x)
276299

277300

278301
def atan(x):
279302
"""Arctangent function that works with both floats and DNumber objects."""
280303
if isinstance(x, DNumber):
281-
new_value = math.atan(x.value)
304+
new_value = atan(x.value)
282305
derivative_factor = 1 / (1 + x.value**2)
283306
new_derivatives = {
284307
var: deriv * derivative_factor for var, deriv in x.derivatives.items()
285308
}
286309
return DNumber(new_value, new_derivatives)
287310
else:
288-
return math.atan(x)
311+
m = _get_math_module(x)
312+
return np.arctan(x) if m is np else math.atan(x)
289313

290314

291315
def dmax(x, y):
@@ -325,48 +349,52 @@ def dmin(x, y):
325349
def log10(x):
326350
"""Base-10 logarithm function that works with both floats and DNumber objects."""
327351
if isinstance(x, DNumber):
328-
new_value = math.log10(x.value)
352+
new_value = log10(x.value)
329353
new_derivatives = {
330-
var: deriv / (x.value * math.log(10))
354+
var: deriv / (x.value * log(10))
331355
for var, deriv in x.derivatives.items()
332356
}
333357
return DNumber(new_value, new_derivatives)
334358
else:
335-
return math.log10(x)
359+
m = _get_math_module(x)
360+
return m.log10(x)
336361

337362

338363
def log2(x):
339364
"""Base-2 logarithm function that works with both floats and DNumber objects."""
340365
if isinstance(x, DNumber):
341-
new_value = math.log2(x.value)
366+
new_value = log2(x.value)
342367
new_derivatives = {
343-
var: deriv / (x.value * math.log(2)) for var, deriv in x.derivatives.items()
368+
var: deriv / (x.value * log(2)) for var, deriv in x.derivatives.items()
344369
}
345370
return DNumber(new_value, new_derivatives)
346371
else:
347-
return math.log2(x)
372+
m = _get_math_module(x)
373+
return m.log2(x)
348374

349375

350376
def floor(x):
351377
"""Floor function that works with both floats and DNumber objects."""
352378
if isinstance(x, DNumber):
353-
new_value = math.floor(x.value)
379+
new_value = floor(x.value)
354380
# Derivative of floor is 0 everywhere except at integer points (where it's undefined)
355381
new_derivatives = {var: 0.0 for var in x.derivatives.keys()}
356382
return DNumber(new_value, new_derivatives)
357383
else:
358-
return math.floor(x)
384+
m = _get_math_module(x)
385+
return m.floor(x)
359386

360387

361388
def ceil(x):
362389
"""Ceiling function that works with both floats and DNumber objects."""
363390
if isinstance(x, DNumber):
364-
new_value = math.ceil(x.value)
391+
new_value = ceil(x.value)
365392
# Derivative of ceil is 0 everywhere except at integer points (where it's undefined)
366393
new_derivatives = {var: 0.0 for var in x.derivatives.keys()}
367394
return DNumber(new_value, new_derivatives)
368395
else:
369-
return math.ceil(x)
396+
m = _get_math_module(x)
397+
return m.ceil(x)
370398

371399

372400
def pow(x, y):

src/dyno/model.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,16 @@ def _repr_html_(self):
219219
</ul>
220220
"""
221221

222-
def solve(self: Self, method: Solver = "qz") -> RecursiveSolution:
222+
def solve(self: Self):
223+
224+
if self.checks['deterministic']:
225+
from .solver import deterministic_solve
226+
sol = deterministic_solve(self)
227+
return sol
228+
else:
229+
dr = self.perturb()
230+
231+
def perturb(self: Self, method: Solver = "qz") -> RecursiveSolution:
223232
"""linearizes the model
224233
225234
Parameters

src/dyno/solver.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def moments(X: TMatrix, Y: TMatrix, Σ: TMatrix) -> tuple[TMatrix, TMatrix]:
393393

394394

395395
import time
396-
396+
old_print = print
397397

398398
def newton(f, x, verbose=False, tol=1e-6, maxit=5, jactype="serial"):
399399
"""Solve nonlinear system using safeguarded Newton iterations
@@ -470,7 +470,7 @@ def newton(f, x, verbose=False, tol=1e-6, maxit=5, jactype="serial"):
470470
warnings.warn("Did not converge")
471471
return [x, it]
472472

473-
def deterministic_solve(model, x0=None, T=None, method="hybr"):
473+
def deterministic_solve(model, x0=None, T=None, method="hybr", verbose=True):
474474

475475
import scipy.optimize
476476
import pandas
@@ -492,10 +492,13 @@ def deterministic_solve(model, x0=None, T=None, method="hybr"):
492492
# w0 = res.x.reshape(v0.shape)
493493

494494
u0 = np.array(v0).ravel()
495+
496+
495497
res, nit = newton(
496498
lambda u: model.deterministic_residuals_with_jacobian(u, sparsify=True),
497499
u0,
498500
jactype="sparse",
501+
verbose=verbose,
499502
)
500503

501504
w0 = res.reshape(v0.shape)

src/dyno/symbolic_model.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,28 @@ def latex_equations(self):
4949

5050
def compute_residuals(self, y2, y1, y0, e):
5151

52-
fe = self.data.evaluator
52+
import math
5353
endogenous = self.symbols["endogenous"]
5454
exogenous = self.symbols["exogenous"]
5555

56+
import copy
57+
cc = copy.deepcopy(self.data.context)
58+
5659
for i, name in enumerate(endogenous):
57-
fe.variables[name] = {-1: y0[i], 0: y1[i], 1: y2[i]}
60+
cc['variables'][name] = {
61+
-1: cc['steady_states'].get(name, math.nan),
62+
0: cc['steady_states'].get(name, math.nan),
63+
1: cc['steady_states'].get(name, math.nan),
64+
}
65+
5866
for i, name in enumerate(exogenous):
59-
fe.variables[name] = {0: e[i]}
67+
cc['variables'][name] = {0: 0.0}
6068

61-
results = [fe.visit(eq) for eq in fe.equations]
69+
70+
from dyno.dynsym.analyze import EquationsEvaluator
71+
E = EquationsEvaluator(cc)
72+
73+
results = [E.visit(eq) for eq in self.data.equations]
6274

6375
r = np.array([float(el) for el in results])
6476

0 commit comments

Comments
 (0)