This repository was archived by the owner on Nov 17, 2025. It is now read-only.
Initialize RaveledParamsMap with dictionaries #62
Open
Labels: enhancement, help wanted
Description
Currently, one has to pass an iterable (which is then converted to a tuple) to initialize `RaveledParamsMap`:
```python
import aesara.tensor as at
from aehmc.utils import RaveledParamsMap

# Value variables for the two parameters
tau_vv = at.vector("tau")
lambda_vv = at.vector("lambda")

# The map is built from, and keyed by, the same variables
rp_map = RaveledParamsMap((tau_vv, lambda_vv))

# Concatenate the parameters into one flat vector, then split it back
q = rp_map.ravel_params((tau_vv, lambda_vv))
tau_part = rp_map.unravel_params(q)[tau_vv]
lambda_part = rp_map.unravel_params(q)[lambda_vv]
```

In some circumstances we need the map to be indexed by other variables. For instance, when working with transformed variables, the map must link the original value variables to the transformed variables (which may have different shapes or dtypes). In this case we need to overwrite `ref_params`:
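```python
from aeppl.transforms import LogTransform

# Continues the previous snippet (tau_vv, lambda_vv, RaveledParamsMap).
# Build the map from the transformed variable, so that shapes and dtypes
# are taken from `lambda_vv_trans`...
lambda_vv_trans = LogTransform().forward(lambda_vv)
rp_map_trans = RaveledParamsMap((tau_vv, lambda_vv_trans))

# ...then overwrite `ref_params` so the unraveled dict is keyed by the
# original value variables instead
rp_map_trans.ref_params = (tau_vv, lambda_vv)

q = rp_map_trans.ravel_params((tau_vv, lambda_vv))
tau_part = rp_map_trans.unravel_params(q)[tau_vv]
lambda_trans_part = rp_map_trans.unravel_params(q)[lambda_vv]
```

I suggest simplifying this by allowing `RaveledParamsMap` to be initialized with a dictionary: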
```python
rp_map_trans = RaveledParamsMap({tau_vv: tau_vv, lambda_vv: lambda_vv_trans})
```

Shapes and dtypes would be inferred from the dictionary's values, while the map would be indexed by the dictionary's keys.
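To make the proposed semantics concrete, here is a minimal sketch of a dictionary-initialized map written against NumPy arrays rather than Aesara graphs. The class name `DictRaveledParamsMap` is hypothetical and not part of aehmc; string keys stand in for the original value variables, and the dictionary values only act as reference arrays from which shapes and dtypes are read. It illustrates the idea under those assumptions and is not the aehmc implementation.

```python
from collections.abc import Mapping

import numpy as np


class DictRaveledParamsMap:
    """Hypothetical NumPy analogue of the proposed dictionary-based constructor."""

    def __init__(self, ref_params):
        # An iterable of (hashable) params maps each param to itself, which
        # reproduces the current tuple-based behavior.
        if not isinstance(ref_params, Mapping):
            ref_params = {p: p for p in ref_params}
        # Keys index the unraveled output...
        self.ref_params = tuple(ref_params.keys())
        # ...while shapes, sizes and dtypes are read from the values.
        values = [np.asarray(v) for v in ref_params.values()]
        self.ref_shapes = [v.shape for v in values]
        self.ref_sizes = [v.size for v in values]
        self.ref_dtypes = [v.dtype for v in values]

    def ravel_params(self, params):
        # Flatten each parameter and concatenate them into one vector.
        return np.concatenate([np.ravel(np.asarray(p)) for p in params])

    def unravel_params(self, raveled):
        # Split at the cumulative sizes, then restore each shape and dtype.
        chunks = np.split(raveled, np.cumsum(self.ref_sizes)[:-1])
        return {
            key: chunk.reshape(shape).astype(dtype)
            for key, chunk, shape, dtype in zip(
                self.ref_params, chunks, self.ref_shapes, self.ref_dtypes
            )
        }


# Usage: "tau"/"lambda" play the role of the original value variables, and
# the values only provide reference shapes/dtypes (e.g. a transformed lambda).
rp_map = DictRaveledParamsMap(
    {"tau": np.zeros(2), "lambda": np.zeros(3, dtype=np.float32)}
)
q = rp_map.ravel_params([np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0])])
parts = rp_map.unravel_params(q)
# parts["lambda"] -> array([3., 4., 5.], dtype=float32)
```

Passing a plain iterable falls back to `{p: p for p in params}`, mirroring the current tuple-based constructor; since Aesara variables are hashable (they are already used as dictionary keys above), the same normalization could presumably be applied to them.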