This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Initialize RaveledParamsMap with dictionaries #62

@rlouf

Description


Currently, one has to pass an iterable (which is then converted to a tuple) to initialize RaveledParamsMap:

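import aesara.tensor as at
from aehmc.utils import RaveledParamsMap

tau_vv = at.vector("tau")
lambda_vv = at.vector("lambda")

# The map is indexed by the same variables that define its shapes and dtypes
rp_map = RaveledParamsMap((tau_vv, lambda_vv))

# Flatten both parameters into a single vector, then recover each part
q = rp_map.ravel_params((tau_vv, lambda_vv))
tau_part = rp_map.unravel_params(q)[tau_vv]
lambda_part = rp_map.unravel_params(q)[lambda_vv]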
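Conceptually, raveling flattens each parameter and concatenates the results into a single vector, while unraveling splits that vector at the recorded sizes and reshapes (and casts) each piece back. A minimal NumPy sketch of that round trip, assuming NumPy arrays stand in for the Aesara variables; the helpers `ravel` and `unravel` are hypothetical names, not aehmc's API:

```python
import numpy as np


def ravel(params):
    """Flatten each array and concatenate the results into one 1-D vector."""
    return np.concatenate([np.ravel(p) for p in params])


def unravel(q, shapes, dtypes):
    """Split a 1-D vector back into arrays with the given shapes and dtypes."""
    sizes = [int(np.prod(shape)) for shape in shapes]
    # Cumulative sizes give the split points between consecutive parameters.
    splits = np.cumsum(sizes)[:-1]
    return [
        piece.reshape(shape).astype(dtype)
        for piece, shape, dtype in zip(np.split(q, splits), shapes, dtypes)
    ]


tau = np.array([1.0, 2.0])
lam = np.array([[3.0, 4.0], [5.0, 6.0]])

q = ravel((tau, lam))  # array([1., 2., 3., 4., 5., 6.])
tau_part, lam_part = unravel(q, (tau.shape, lam.shape), (tau.dtype, lam.dtype))
```

The shapes and dtypes used to unravel are whatever was recorded from the reference parameters, which is why the question of which variables the map is built from (and which it is indexed by) matters below.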

In some circumstances, we need the map to be indexed by other variables. For instance, when we work with transformed variables, the map needs to link the original value variables to the transformed variables (which may have different shapes/dtypes). In this case, we need to overwrite the ref_params property:

from aeppl.transforms import LogTransform

lambda_vv_trans = LogTransform().forward(lambda_vv)

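# Shapes and dtypes are taken from the variables passed here...
rp_map_trans = RaveledParamsMap((tau_vv, lambda_vv_trans))
# ...but the map should be indexed by the original value variables
rp_map_trans.ref_params = (tau_vv, lambda_vv)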

q = rp_map_trans.ravel_params((tau_vv, lambda_vv))
tau_part = rp_map_trans.unravel_params(q)[tau_vv]
lambda_trans_part = rp_map_trans.unravel_params(q)[lambda_vv]

I suggest simplifying this by allowing RaveledParamsMap to be initialized with a dictionary:

rp_map_trans = RaveledParamsMap({tau_vv: tau_vv, lambda_vv: lambda_vv_trans})

Shapes and dtypes would be inferred from the dictionary's values, and the map would be indexed by the dictionary's keys.
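The proposal could be sketched roughly as follows. This is a hypothetical, NumPy-only stand-in rather than aehmc's implementation: the class name `DictRaveledParamsMap` is made up for illustration, string keys play the role of the original value variables, and NumPy arrays play the role of the (possibly transformed) Aesara variables whose shapes and dtypes define the raveled layout.

```python
from collections.abc import Iterable, Mapping
from typing import Union

import numpy as np


class DictRaveledParamsMap:
    """Hypothetical sketch of a raveled-params map built from a dict.

    Keys index the map; shapes and dtypes are inferred from the values.
    """

    def __init__(self, params: Union[Mapping, Iterable]):
        if not isinstance(params, Mapping):
            # Current behavior: each entry indexes itself. This requires
            # hashable entries (Aesara variables are hashable).
            params = {p: p for p in params}
        self.ref_params = tuple(params.keys())
        self.shapes = tuple(np.shape(v) for v in params.values())
        self.dtypes = tuple(np.asarray(v).dtype for v in params.values())
        self.sizes = tuple(int(np.prod(s)) for s in self.shapes)

    def ravel_params(self, values):
        """Flatten and concatenate values given in the order of ref_params."""
        return np.concatenate([np.ravel(v) for v in values])

    def unravel_params(self, q):
        """Split q into arrays keyed by the reference parameters."""
        splits = np.cumsum(self.sizes)[:-1]
        return {
            key: piece.reshape(shape).astype(dtype)
            for key, piece, shape, dtype in zip(
                self.ref_params, np.split(q, splits), self.shapes, self.dtypes
            )
        }


# String keys stand in for the original value variables; the values play
# the role of the (possibly transformed) variables that fix shapes/dtypes.
rp_map = DictRaveledParamsMap(
    {"tau": np.zeros(2), "lambda": np.zeros(3, dtype=np.float32)}
)
q = rp_map.ravel_params((np.array([1.0, 2.0]), np.array([0.1, 0.2, 0.3])))
parts = rp_map.unravel_params(q)
```

Still accepting an iterable would presumably keep existing call sites working: converting it to `{p: p for p in params}` reproduces the current behavior of indexing each parameter by itself, so the ref_params override would no longer be needed.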

Metadata


Assignees

No one assigned

    Labels

    enhancement (New feature or request), help wanted (Extra attention is needed)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
