Skip to content

Commit e84f8dc

Browse files
feat: Make sympy conversion doublecheck
1 parent 9c87f1a commit e84f8dc

File tree

2 files changed

+93
-32
lines changed

2 files changed

+93
-32
lines changed

asv_bench/benchmarks/inspect_to_sympy.py

Lines changed: 91 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,14 @@
3131
import inspect
3232
import textwrap
3333
from typing import Any
34+
from typing import Callable
3435
from typing import Dict
3536
from typing import List
3637
from typing import Tuple
3738

39+
import numpy as np
3840
import sympy as sp
41+
from dysts.base import BaseDyn
3942

4043

4144
def _is_name(node: ast.AST, name: str) -> bool:
@@ -181,10 +184,6 @@ def visit_Subscript(self, node: ast.Subscript):
181184
idx_node = node.slice
182185
if isinstance(idx_node, ast.Constant):
183186
idx = idx_node.value
184-
elif isinstance(idx_node, ast.Index) and isinstance(
185-
idx_node.value, ast.Constant
186-
):
187-
idx = idx_node.value.value
188187
else:
189188
raise NotImplementedError(
190189
"Only constant indices into state vector supported"
@@ -200,16 +199,76 @@ def visit_Subscript(self, node: ast.Subscript):
200199
raise NotImplementedError("Unsupported subscript pattern")
201200

202201

203-
def object_to_sympy_rhs(
202+
def _numeric_consistency_check(
203+
dysts_flow: BaseDyn,
204+
rhsfunc: Callable,
205+
arg_names: List[str],
206+
state_names: List[str],
207+
vector_mode: bool,
208+
sys_dim: int,
209+
lambda_rhs: sp.Lambda,
210+
) -> None:
211+
"""Compare the original dysts rhs function to the SymPy-derived lambda.
212+
213+
Raises a RuntimeError if they disagree.
214+
"""
215+
# default to nonnegative support (e.g. Lotka volterra)
216+
random_state = np.random.standard_exponential(size=sys_dim)
217+
218+
# Construct call arguments for the original function (bound method).
219+
call_args = []
220+
for name in arg_names:
221+
if name == "self":
222+
continue
223+
if name in state_names and not vector_mode:
224+
idx = state_names.index(name)
225+
call_args.append(random_state[idx])
226+
elif name in state_names and vector_mode:
227+
call_args.append(np.asarray(random_state, dtype=float))
228+
elif name == "t":
229+
call_args.append(float(np.random.standard_normal(size=())))
230+
else:
231+
call_args.append(dysts_flow.params[name])
232+
233+
dysts_val = rhsfunc(*call_args)
234+
orig_arr = np.asarray(dysts_val, dtype=float).ravel()
235+
236+
sym_val = lambda_rhs(*tuple(random_state))
237+
sym_arr = np.asarray(sym_val, dtype=float).ravel()
238+
239+
if orig_arr.shape != sym_arr.shape:
240+
raise RuntimeError(
241+
f"_rhs shape {orig_arr.shape} != sympy shape {sym_arr.shape}"
242+
)
243+
244+
if not np.allclose(orig_arr, sym_arr, rtol=1e-6, atol=1e-9):
245+
raise RuntimeError("Numeric mismatch between original and sympy conversion.")
246+
247+
248+
def dynsys_to_sympy(
204249
obj: Any, func_name: str = "_rhs"
205250
) -> Tuple[List[sp.Symbol], List[sp.Expr], sp.Lambda]:
206251
"""Inspect ``obj`` for a method named ``func_name`` and return a SymPy
207252
representation of its RHS.
208253
209-
Returns a tuple ``(state_symbols, exprs, lambda_rhs)`` where ``state_symbols``
254+
Returns:
255+
a tuple ``(state_symbols, exprs, lambda_rhs)`` where ``state_symbols``
210256
is a list of SymPy symbols for the state vector, ``exprs`` is a list of
211257
SymPy expressions for the RHS components, and ``lambda_rhs`` is a SymPy
212258
Lambda mapping the state symbols to the RHS vector.
259+
260+
Example:
261+
262+
>>> from dysts.flows import Lorenz
263+
>>> from inspect_to_sympy import dynsys_to_sympy
264+
>>> lor = Lorenz()
265+
>>> symbols, exprs, lambda_rhs = dynsys_to_sympy(lor)
266+
>>> print(lor._rhs(1, 2, 3, t=0.0, **lor.params))
267+
(10, 23, -6.0009999999999994)
268+
269+
>>> print(tuple(lambda_rhs(1, 2, 3)))
270+
(10, 23, -6.00100000000000)
271+
213272
"""
214273

215274
if not hasattr(obj, func_name):
@@ -241,7 +300,7 @@ def object_to_sympy_rhs(
241300
start_idx = 1
242301

243302
vector_mode = False
244-
state_arg_names: List[str]
303+
state_args: List[str]
245304
t_idx = None
246305
if "t" in arg_names:
247306
t_idx = arg_names.index("t")
@@ -251,33 +310,33 @@ def object_to_sympy_rhs(
251310
if t_idx is not None:
252311
potential = arg_names[start_idx:t_idx]
253312
if len(potential) >= n_state:
254-
state_arg_names = potential[:n_state]
313+
state_args = potential[:n_state]
255314
else:
256-
state_arg_names = [arg_names[start_idx]]
315+
state_args = [arg_names[start_idx]]
257316
vector_mode = True
258317
else:
259318
potential = arg_names[start_idx:]
260319
if len(potential) >= n_state:
261-
state_arg_names = potential[:n_state]
320+
state_args = potential[:n_state]
262321
else:
263-
state_arg_names = [arg_names[start_idx]]
322+
state_args = [arg_names[start_idx]]
264323
vector_mode = True
265324
else:
266325
if t_idx is not None:
267-
state_arg_names = arg_names[start_idx:t_idx]
268-
if len(state_arg_names) == 0:
269-
state_arg_names = [arg_names[start_idx]]
326+
state_args = arg_names[start_idx:t_idx]
327+
if len(state_args) == 0:
328+
state_args = [arg_names[start_idx]]
270329
vector_mode = True
271-
elif len(state_arg_names) == 1:
330+
elif len(state_args) == 1:
272331
# single name could be vector or scalar; assume vector-mode
273332
vector_mode = True
274333
else:
275-
state_arg_names = [arg_names[start_idx]]
334+
state_args = [arg_names[start_idx]]
276335
vector_mode = True
277336

278337
# If vector_mode, inspect AST for subscript/index usage or tuple unpacking
279338
if vector_mode:
280-
state_name = state_arg_names[0]
339+
state_name = state_args[0]
281340
max_index = -1
282341
unpack_size = None
283342
for node in ast.walk(fndef):
@@ -290,13 +349,6 @@ def object_to_sympy_rhs(
290349
if isinstance(sl, ast.Constant) and isinstance(sl.value, int):
291350
if sl.value > max_index:
292351
max_index = sl.value
293-
elif (
294-
isinstance(sl, ast.Index)
295-
and isinstance(sl.value, ast.Constant)
296-
and isinstance(sl.value.value, int)
297-
):
298-
if sl.value.value > max_index:
299-
max_index = sl.value.value
300352
if isinstance(node, ast.Assign):
301353
if isinstance(node.value, ast.Name) and node.value.id == state_name:
302354
targets = node.targets
@@ -316,12 +368,12 @@ def object_to_sympy_rhs(
316368
primary_state_name = state_name
317369
else:
318370
# individual state args -> use their arg names as symbol names
319-
state_symbols = [sp.Symbol(n) for n in state_arg_names]
320-
primary_state_name = state_arg_names[0] if len(state_arg_names) > 0 else "x"
371+
state_symbols = [sp.Symbol(n) for n in state_args]
372+
primary_state_name = state_args[0] if len(state_args) > 0 else "x"
321373

322374
# Build locals mapping from known state arg names and parameters
323375
locals_map: Dict[str, Any] = {}
324-
for i, name in enumerate(state_arg_names):
376+
for i, name in enumerate(state_args):
325377
if i < len(state_symbols):
326378
locals_map[name] = state_symbols[i]
327379

@@ -364,7 +416,6 @@ def object_to_sympy_rhs(
364416
return_expr = stmt.value
365417

366418
if return_expr is None:
367-
raise ValueError("No return statement found in function")
368419
# maybe last statement is an Expr with list construction;
369420
# try to find a Return node deep
370421
for node in ast.walk(fndef):
@@ -379,7 +430,6 @@ def object_to_sympy_rhs(
379430
converter = _ASTToSympy(primary_state_name, state_symbols, locals_map)
380431
rhs_val = converter.visit(return_expr)
381432

382-
# Normalize rhs_val to a list/tuple of expressions
383433
if isinstance(rhs_val, (list, tuple)):
384434
exprs = list(rhs_val)
385435
else:
@@ -388,7 +438,18 @@ def object_to_sympy_rhs(
388438

389439
lambda_rhs = sp.Lambda(tuple(state_symbols), sp.Matrix(exprs))
390440

441+
# Run numeric consistency guard (raises on mismatch)
442+
_numeric_consistency_check(
443+
obj,
444+
func,
445+
arg_names,
446+
state_args,
447+
vector_mode,
448+
len(state_symbols),
449+
lambda_rhs,
450+
)
451+
391452
return state_symbols, exprs, lambda_rhs
392453

393454

394-
__all__ = ["object_to_sympy_rhs"]
455+
__all__ = ["dynsys_to_sympy"]

test/test_inspect_to_sympy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
"dysts not available; skipping inspect_to_sympy tests", allow_module_level=True
88
)
99

10-
from asv_bench.benchmarks.inspect_to_sympy import object_to_sympy_rhs
10+
from asv_bench.benchmarks.inspect_to_sympy import dynsys_to_sympy
1111

1212

1313
def test_lorenz_to_sympy():
1414
lor = Lorenz()
15-
symbols, exprs, lambda_rhs = object_to_sympy_rhs(lor, func_name="_rhs")
15+
symbols, exprs, lambda_rhs = dynsys_to_sympy(lor, func_name="_rhs")
1616
assert len(symbols) == lor.dimension
1717
# evaluate lambda with simple numeric values
1818
vals = tuple(float(i + 1) for i in range(lor.dimension))

0 commit comments

Comments
 (0)