3131import inspect
3232import textwrap
3333from typing import Any
34+ from typing import Callable
3435from typing import Dict
3536from typing import List
3637from typing import Tuple
3738
39+ import numpy as np
3840import sympy as sp
41+ from dysts .base import BaseDyn
3942
4043
4144def _is_name (node : ast .AST , name : str ) -> bool :
@@ -181,10 +184,6 @@ def visit_Subscript(self, node: ast.Subscript):
181184 idx_node = node .slice
182185 if isinstance (idx_node , ast .Constant ):
183186 idx = idx_node .value
184- elif isinstance (idx_node , ast .Index ) and isinstance (
185- idx_node .value , ast .Constant
186- ):
187- idx = idx_node .value .value
188187 else :
189188 raise NotImplementedError (
190189 "Only constant indices into state vector supported"
@@ -200,16 +199,76 @@ def visit_Subscript(self, node: ast.Subscript):
200199 raise NotImplementedError ("Unsupported subscript pattern" )
201200
202201
203- def object_to_sympy_rhs (
202+ def _numeric_consistency_check (
203+ dysts_flow : BaseDyn ,
204+ rhsfunc : Callable ,
205+ arg_names : List [str ],
206+ state_names : List [str ],
207+ vector_mode : bool ,
208+ sys_dim : int ,
209+ lambda_rhs : sp .Lambda ,
210+ ) -> None :
211+ """Compare the original dysts rhs function to the SymPy-derived lambda.
212+
213+ Raises a RuntimeError if they disagree.
214+ """
215+ # default to nonnegative support (e.g. Lotka volterra)
216+ random_state = np .random .standard_exponential (size = sys_dim )
217+
218+ # Construct call arguments for the original function (bound method).
219+ call_args = []
220+ for name in arg_names :
221+ if name == "self" :
222+ continue
223+ if name in state_names and not vector_mode :
224+ idx = state_names .index (name )
225+ call_args .append (random_state [idx ])
226+ elif name in state_names and vector_mode :
227+ call_args .append (np .asarray (random_state , dtype = float ))
228+ elif name == "t" :
229+ call_args .append (float (np .random .standard_normal (size = ())))
230+ else :
231+ call_args .append (dysts_flow .params [name ])
232+
233+ dysts_val = rhsfunc (* call_args )
234+ orig_arr = np .asarray (dysts_val , dtype = float ).ravel ()
235+
236+ sym_val = lambda_rhs (* tuple (random_state ))
237+ sym_arr = np .asarray (sym_val , dtype = float ).ravel ()
238+
239+ if orig_arr .shape != sym_arr .shape :
240+ raise RuntimeError (
241+ f"_rhs shape { orig_arr .shape } != sympy shape { sym_arr .shape } "
242+ )
243+
244+ if not np .allclose (orig_arr , sym_arr , rtol = 1e-6 , atol = 1e-9 ):
245+ raise RuntimeError ("Numeric mismatch between original and sympy conversion." )
246+
247+
248+ def dynsys_to_sympy (
204249 obj : Any , func_name : str = "_rhs"
205250) -> Tuple [List [sp .Symbol ], List [sp .Expr ], sp .Lambda ]:
206251 """Inspect ``obj`` for a method named ``func_name`` and return a SymPy
207252 representation of its RHS.
208253
209- Returns a tuple ``(state_symbols, exprs, lambda_rhs)`` where ``state_symbols``
254+ Returns:
255+ a tuple ``(state_symbols, exprs, lambda_rhs)`` where ``state_symbols``
210256 is a list of SymPy symbols for the state vector, ``exprs`` is a list of
211257 SymPy expressions for the RHS components, and ``lambda_rhs`` is a SymPy
212258 Lambda mapping the state symbols to the RHS vector.
259+
260+ Example:
261+
262+ >>> from dysts.flows import Lorenz
263+ >>> from inspect_to_sympy import dynsys_to_sympy
264+ >>> lor = Lorenz()
265+ >>> symbols, exprs, lambda_rhs = dynsys_to_sympy(lor)
266+ >>> print(lor._rhs(1, 2, 3, t=0.0, **lor.params))
267+ (10, 23, -6.0009999999999994)
268+
269+ >>> print(tuple(lambda_rhs(1, 2, 3)))
270+ (10, 23, -6.00100000000000)
271+
213272 """
214273
215274 if not hasattr (obj , func_name ):
@@ -241,7 +300,7 @@ def object_to_sympy_rhs(
241300 start_idx = 1
242301
243302 vector_mode = False
244- state_arg_names : List [str ]
303+ state_args : List [str ]
245304 t_idx = None
246305 if "t" in arg_names :
247306 t_idx = arg_names .index ("t" )
@@ -251,33 +310,33 @@ def object_to_sympy_rhs(
251310 if t_idx is not None :
252311 potential = arg_names [start_idx :t_idx ]
253312 if len (potential ) >= n_state :
254- state_arg_names = potential [:n_state ]
313+ state_args = potential [:n_state ]
255314 else :
256- state_arg_names = [arg_names [start_idx ]]
315+ state_args = [arg_names [start_idx ]]
257316 vector_mode = True
258317 else :
259318 potential = arg_names [start_idx :]
260319 if len (potential ) >= n_state :
261- state_arg_names = potential [:n_state ]
320+ state_args = potential [:n_state ]
262321 else :
263- state_arg_names = [arg_names [start_idx ]]
322+ state_args = [arg_names [start_idx ]]
264323 vector_mode = True
265324 else :
266325 if t_idx is not None :
267- state_arg_names = arg_names [start_idx :t_idx ]
268- if len (state_arg_names ) == 0 :
269- state_arg_names = [arg_names [start_idx ]]
326+ state_args = arg_names [start_idx :t_idx ]
327+ if len (state_args ) == 0 :
328+ state_args = [arg_names [start_idx ]]
270329 vector_mode = True
271- elif len (state_arg_names ) == 1 :
330+ elif len (state_args ) == 1 :
272331 # single name could be vector or scalar; assume vector-mode
273332 vector_mode = True
274333 else :
275- state_arg_names = [arg_names [start_idx ]]
334+ state_args = [arg_names [start_idx ]]
276335 vector_mode = True
277336
278337 # If vector_mode, inspect AST for subscript/index usage or tuple unpacking
279338 if vector_mode :
280- state_name = state_arg_names [0 ]
339+ state_name = state_args [0 ]
281340 max_index = - 1
282341 unpack_size = None
283342 for node in ast .walk (fndef ):
@@ -290,13 +349,6 @@ def object_to_sympy_rhs(
290349 if isinstance (sl , ast .Constant ) and isinstance (sl .value , int ):
291350 if sl .value > max_index :
292351 max_index = sl .value
293- elif (
294- isinstance (sl , ast .Index )
295- and isinstance (sl .value , ast .Constant )
296- and isinstance (sl .value .value , int )
297- ):
298- if sl .value .value > max_index :
299- max_index = sl .value .value
300352 if isinstance (node , ast .Assign ):
301353 if isinstance (node .value , ast .Name ) and node .value .id == state_name :
302354 targets = node .targets
@@ -316,12 +368,12 @@ def object_to_sympy_rhs(
316368 primary_state_name = state_name
317369 else :
318370 # individual state args -> use their arg names as symbol names
319- state_symbols = [sp .Symbol (n ) for n in state_arg_names ]
320- primary_state_name = state_arg_names [0 ] if len (state_arg_names ) > 0 else "x"
371+ state_symbols = [sp .Symbol (n ) for n in state_args ]
372+ primary_state_name = state_args [0 ] if len (state_args ) > 0 else "x"
321373
322374 # Build locals mapping from known state arg names and parameters
323375 locals_map : Dict [str , Any ] = {}
324- for i , name in enumerate (state_arg_names ):
376+ for i , name in enumerate (state_args ):
325377 if i < len (state_symbols ):
326378 locals_map [name ] = state_symbols [i ]
327379
@@ -364,7 +416,6 @@ def object_to_sympy_rhs(
364416 return_expr = stmt .value
365417
366418 if return_expr is None :
367- raise ValueError ("No return statement found in function" )
368419 # maybe last statement is an Expr with list construction;
369420 # try to find a Return node deep
370421 for node in ast .walk (fndef ):
@@ -379,7 +430,6 @@ def object_to_sympy_rhs(
379430 converter = _ASTToSympy (primary_state_name , state_symbols , locals_map )
380431 rhs_val = converter .visit (return_expr )
381432
382- # Normalize rhs_val to a list/tuple of expressions
383433 if isinstance (rhs_val , (list , tuple )):
384434 exprs = list (rhs_val )
385435 else :
@@ -388,7 +438,18 @@ def object_to_sympy_rhs(
388438
389439 lambda_rhs = sp .Lambda (tuple (state_symbols ), sp .Matrix (exprs ))
390440
441+ # Run numeric consistency guard (raises on mismatch)
442+ _numeric_consistency_check (
443+ obj ,
444+ func ,
445+ arg_names ,
446+ state_args ,
447+ vector_mode ,
448+ len (state_symbols ),
449+ lambda_rhs ,
450+ )
451+
391452 return state_symbols , exprs , lambda_rhs
392453
393454
394- __all__ = ["object_to_sympy_rhs " ]
455+ __all__ = ["dynsys_to_sympy " ]
0 commit comments