diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 9220048c7..ca0576b87 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict from collections.abc import Callable from typing import Any, Protocol, runtime_checkable @@ -15,7 +14,7 @@ ModelT: TypeAlias = Callable[P, Any] Message: TypeAlias = dict[str, Any] -TraceT: TypeAlias = OrderedDict[str, Message] +TraceT: TypeAlias = dict[str, Message] @runtime_checkable diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 901093768..82102656a 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -1,7 +1,6 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict from functools import partial from typing import Callable, Optional @@ -493,9 +492,7 @@ def g(*args, **kwargs): dim_to_name = msg["infer"].get("dim_to_name") to_funsor( msg["value"], - dim_to_name=OrderedDict( - [(k, dim_to_name[k]) for k in sorted(dim_to_name)] - ), + dim_to_name={k: dim_to_name[k] for k in sorted(dim_to_name)}, ) apply_stack(msg) diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py index db074b8ef..3a607c902 100644 --- a/numpyro/contrib/funsor/discrete.py +++ b/numpyro/contrib/funsor/discrete.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict, defaultdict +from collections import defaultdict import functools from jax import random @@ -35,7 +35,7 @@ def _get_support_value_contraction(funsor_dist, name, **kwargs): @_get_support_value.register(funsor.delta.Delta) def _get_support_value_delta(funsor_dist, name, **kwargs): assert name in funsor_dist.fresh - return OrderedDict(funsor_dist.terms)[name][0] + return dict(funsor_dist.terms)[name][0] def _sample_posterior( diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py index 025d4931d..53dc62107 100644 --- a/numpyro/contrib/funsor/enum_messenger.py +++ b/numpyro/contrib/funsor/enum_messenger.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict, namedtuple +from collections import namedtuple from contextlib import ExitStack # python 3 from enum import Enum @@ -71,8 +71,8 @@ class DimStack: def __init__(self): self._stack = [ StackFrame( - name_to_dim=OrderedDict(), - dim_to_name=OrderedDict(), + name_to_dim={}, + dim_to_name={}, parents=(), iter_parents=(), keep=False, @@ -238,7 +238,7 @@ def process_message(self, msg): @staticmethod def _get_name_to_dim(batch_names, name_to_dim=None, dim_type=DimType.LOCAL): - name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim.copy() + name_to_dim = {} if name_to_dim is None else name_to_dim.copy() # interpret all names/dims as requests since we only run this function once for name in batch_names: @@ -256,7 +256,7 @@ def _get_name_to_dim(batch_names, name_to_dim=None, dim_type=DimType.LOCAL): @classmethod # only depends on the global _DIM_STACK state, not self def _pyro_to_data(cls, msg): (funsor_value,) = msg["args"] - name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict()) + name_to_dim = msg["kwargs"].setdefault("name_to_dim", {}) dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL) batch_names = tuple(funsor_value.inputs.keys()) @@ -270,7 +270,7 @@ def _pyro_to_data(cls, msg): @staticmethod def _get_dim_to_name(batch_shape, dim_to_name=None, dim_type=DimType.LOCAL): - dim_to_name = OrderedDict() if dim_to_name is None else dim_to_name.copy() + dim_to_name = {} if dim_to_name is None else dim_to_name.copy() batch_dim = len(batch_shape) # interpret all names/dims as requests since we only run this function once @@ -296,7 +296,7 @@ def _pyro_to_funsor(cls, msg): else: raw_value = msg["args"][0] output = msg["kwargs"].setdefault("output", None) - dim_to_name = msg["kwargs"].setdefault("dim_to_name", OrderedDict()) + dim_to_name = msg["kwargs"].setdefault("dim_to_name", {}) dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL) event_dim = len(output.shape) if output else 0 @@ -359,7 +359,7 @@ def __enter__(self): saved_frame = self._saved_frames.pop() name_to_dim, dim_to_name = saved_frame.name_to_dim, saved_frame.dim_to_name else: - name_to_dim, dim_to_name = OrderedDict(), OrderedDict() + name_to_dim, dim_to_name = {}, {} frame = StackFrame( name_to_dim=name_to_dim, @@ -490,18 +490,14 @@ def __init__(self, name, size, subsample_size=None, dim=None): self.subsample_size = indices.shape[0] self._indices = funsor.Tensor( indices, - OrderedDict([(self.name, funsor.Bint[self.subsample_size])]), + {self.name: funsor.Bint[self.subsample_size]}, self.subsample_size, ) super(plate, self).__init__(None) def __enter__(self): super().__enter__() # do this first to take care of globals recycling - name_to_dim = ( - OrderedDict([(self.name, self.dim)]) - if self.dim is not None - else OrderedDict() - ) + name_to_dim = {self.name: self.dim} if self.dim is not None else {} indices = to_data( self._indices, name_to_dim=name_to_dim, dim_type=DimType.VISIBLE ) @@ -594,9 +590,7 @@ def process_message(self, msg): size = msg["fn"].enumerate_support(expand=False).shape[0] raw_value = jnp.arange(0, size) - funsor_value = funsor.Tensor( - raw_value, OrderedDict([(msg["name"], funsor.Bint[size])]), size - ) + funsor_value = funsor.Tensor(raw_value, {msg["name"]: funsor.Bint[size]}, size) msg["value"] = to_data(funsor_value) msg["done"] = True @@ -661,7 +655,7 @@ def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL): :param x: An object. :param funsor.domains.Domain output: An optional output hint to uniquely convert a data to a Funsor (e.g. when `x` is a string). - :param OrderedDict dim_to_name: An optional mapping from negative + :param dict dim_to_name: An optional mapping from negative batch dimensions to name strings. :param int dim_type: Either 0, 1, or 2. This optional argument indicates a dimension should be treated as 'local', 'global', or 'visible', @@ -669,7 +663,7 @@ def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL): :return: A Funsor equivalent to `x`. :rtype: funsor.terms.Funsor """ - dim_to_name = OrderedDict() if dim_to_name is None else dim_to_name + dim_to_name = {} if dim_to_name is None else dim_to_name initial_msg = { "type": "to_funsor", @@ -691,14 +685,14 @@ def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL): A primitive to extract a python object from a :class:`~funsor.terms.Funsor`. :param ~funsor.terms.Funsor x: A funsor object - :param OrderedDict name_to_dim: An optional inputs hint which maps + :param dict name_to_dim: An optional inputs hint which maps dimension names from `x` to dimension positions of the returned value. :param int dim_type: Either 0, 1, or 2. This optional argument indicates a dimension should be treated as 'local', 'global', or 'visible', which can be used to interact with the global :class:`DimStack`. :return: A non-funsor equivalent to `x`. """ - name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim + name_to_dim = {} if name_to_dim is None else name_to_dim initial_msg = { "type": "to_data", diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py index 5e97f82d2..aefb81856 100644 --- a/numpyro/contrib/funsor/infer_util.py +++ b/numpyro/contrib/funsor/infer_util.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict, defaultdict +from collections import defaultdict from contextlib import contextmanager import functools import re @@ -221,7 +221,7 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op): dim_to_name = site["infer"]["dim_to_name"] - if all(dim == 1 for dim in log_prob.shape) and dim_to_name == OrderedDict(): + if all(dim == 1 for dim in log_prob.shape) and dim_to_name == {}: log_prob = log_prob.squeeze() log_prob_factor = funsor.to_funsor( diff --git a/numpyro/contrib/stochastic_support/dcc.py b/numpyro/contrib/stochastic_support/dcc.py index b58be1275..9e3818327 100644 --- a/numpyro/contrib/stochastic_support/dcc.py +++ b/numpyro/contrib/stochastic_support/dcc.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from collections import OrderedDict, namedtuple -from typing import Any, Callable, OrderedDict as OrderedDictType, Union +from collections import defaultdict +from typing import Any, Callable, Union import jax from jax import random @@ -61,7 +61,7 @@ def __init__(self, model: Callable, num_slp_samples: int, max_slps: int) -> None def _find_slps( self, rng_key: jax.Array, *args: Any, **kwargs: Any - ) -> dict[str, OrderedDictType]: + ) -> dict[str, dict]: """ Discover the straight-line programs (SLPs) in the model by sampling from the prior. This implementation assumes that all branching is done via discrete sampling sites @@ -80,11 +80,11 @@ def _find_slps( return branching_traces - def _get_branching_trace(self, tr: dict[str, Any]) -> OrderedDictType: + def _get_branching_trace(self, tr: dict[str, Any]) -> dict: """ Extract the sites from the trace that are annotated with `infer={"branching": True}`. """ - branching_trace = OrderedDict() + branching_trace = {} for site in tr.values(): if ( site["type"] == "sample" @@ -109,7 +109,7 @@ def _get_branching_trace(self, tr: dict[str, Any]) -> OrderedDictType: def _run_inference( self, rng_key: jax.Array, - branching_trace: OrderedDictType, + branching_trace: dict, *args: Any, **kwargs: Any, ) -> RunInferenceResult: @@ -120,7 +120,7 @@ def _combine_inferences( self, rng_key: jax.Array, inferences: dict[str, Any], - branching_traces: dict[str, OrderedDictType], + branching_traces: dict[str, dict], *args: Any, **kwargs: Any, ) -> Union[DCCResult, SDVIResult]: @@ -139,7 +139,7 @@ def run( rng_key, subkey = random.split(rng_key) branching_traces = self._find_slps(subkey, *args, **kwargs) - inferences = dict() + inferences = {} for key, bt in branching_traces.items(): rng_key, subkey = random.split(rng_key) inferences[key] = self._run_inference(subkey, bt, *args, **kwargs) @@ -209,7 +209,7 @@ def __init__( def _run_inference( self, rng_key: jax.Array, - branching_trace: OrderedDictType, + branching_trace: dict, *args: Any, **kwargs: Any, ) -> RunInferenceResult: @@ -227,7 +227,7 @@ def _combine_inferences( # type: ignore[override] self, rng_key: jax.Array, samples: dict[str, Any], - branching_traces: dict[str, OrderedDictType], + branching_traces: dict[str, dict], *args: Any, **kwargs: Any, ) -> DCCResult: diff --git a/numpyro/contrib/stochastic_support/sdvi.py b/numpyro/contrib/stochastic_support/sdvi.py index 82a0299d7..b7b9769c6 100644 --- a/numpyro/contrib/stochastic_support/sdvi.py +++ b/numpyro/contrib/stochastic_support/sdvi.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, OrderedDict as OrderedDictType +from typing import Any, Callable import jax import jax.numpy as jnp @@ -98,7 +98,7 @@ def __init__( def _run_inference( self, rng_key: jax.Array, - branching_trace: OrderedDictType, + branching_trace: dict, *args: Any, **kwargs: Any, ) -> RunInferenceResult: @@ -121,7 +121,7 @@ def _combine_inferences( # type: ignore[override] self, rng_key: jax.Array, guides: dict[str, tuple[AutoGuide, dict[str, Any]]], - branching_traces: dict[str, OrderedDictType], + branching_traces: dict[str, dict], *args: Any, **kwargs: Any, ) -> SDVIResult: diff --git a/numpyro/diagnostics.py b/numpyro/diagnostics.py index 746289f2e..40985f109 100644 --- a/numpyro/diagnostics.py +++ b/numpyro/diagnostics.py @@ -5,7 +5,6 @@ This provides a small set of utilities in NumPyro that are used to diagnose posterior samples. """ -from collections import OrderedDict from itertools import product from typing import Union @@ -271,17 +270,15 @@ def summary( r_hat = split_gelman_rubin(value) hpd_lower = "{:.1f}%".format(50 * (1 - prob)) hpd_upper = "{:.1f}%".format(50 * (1 + prob)) - summary_dict[name] = OrderedDict( - [ - ("mean", mean), - ("std", std), - ("median", median), - (hpd_lower, hpd[0]), - (hpd_upper, hpd[1]), - ("n_eff", n_eff), - ("r_hat", r_hat), - ] - ) + summary_dict[name] = { + "mean": mean, + "std": std, + "median": median, + hpd_lower: hpd[0], + hpd_upper: hpd[1], + "n_eff": n_eff, + "r_hat": r_hat, + } return summary_dict diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 7e67edad5..5d704fcc6 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -26,7 +26,6 @@ # POSSIBILITY OF SUCH DAMAGE. -from collections import OrderedDict from contextlib import contextmanager import functools import inspect @@ -620,7 +619,7 @@ def __init__( @staticmethod def _broadcast_shape( existing_shape: tuple[int, ...], new_shape: tuple[int, ...] - ) -> tuple[tuple[int, ...], OrderedDict, OrderedDict]: + ) -> tuple[tuple[int, ...], dict, dict]: if len(new_shape) < len(existing_shape): raise ValueError( "Cannot broadcast distribution of shape {} to shape {}".format( @@ -645,8 +644,8 @@ def _broadcast_shape( ) return ( tuple(reversed(reversed_shape)), - OrderedDict(reversed(expanded_sizes)), - OrderedDict(interstitial_sizes), + dict(reversed(expanded_sizes)), + dict(interstitial_sizes), ) @property diff --git a/numpyro/handlers.py b/numpyro/handlers.py index b57700b9c..7298d881f 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -93,7 +93,6 @@ def seeded_model(data): from __future__ import annotations -from collections import OrderedDict from types import TracebackType from typing import Callable, Optional, Union import warnings @@ -155,19 +154,19 @@ class trace(Messenger): >>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace() >>> pp.pprint(exec_trace) # doctest: +SKIP - OrderedDict([('a', + {'a': {'args': (), 'fn': , 'is_observed': False, 'kwargs': {'rng_key': Array([0, 0], dtype=uint32)}, 'name': 'a', 'type': 'sample', - 'value': Array(-0.20584235, dtype=float32)})]) + 'value': Array(-0.20584235, dtype=float32)}} """ def __enter__(self) -> TraceT: # type: ignore [override] super(trace, self).__enter__() - self.trace: TraceT = OrderedDict() + self.trace: TraceT = {} return self.trace def postprocess_message(self, msg: Message) -> None: @@ -188,7 +187,7 @@ def get_trace(self, *args, **kwargs) -> TraceT: :param `*args`: arguments to the callable. :param `**kwargs`: keyword arguments to the callable. - :return: `OrderedDict` containing the execution trace. + :return: `dict` containing the execution trace. """ self(*args, **kwargs) return self.trace @@ -201,7 +200,7 @@ class replay(Messenger): values from the corresponding site names in `trace`. :param fn: Python callable with NumPyro primitives. - :param trace: an OrderedDict containing execution metadata. + :param trace: dict containing execution metadata. **Example:** diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index fc538ffa4..3031c7221 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -3,7 +3,7 @@ from __future__ import annotations -from collections import OrderedDict, defaultdict +from collections import defaultdict from collections.abc import Callable from functools import partial from typing import TYPE_CHECKING, Any, TypedDict, TypeVar @@ -1025,9 +1025,7 @@ def _partition( model_sum_deps: dict[str, frozenset[str]], sum_vars: frozenset[str] ) -> list[tuple[frozenset[str], frozenset[str]]]: # Construct a bipartite graph between model_sum_deps and the sum_vars - neighbors: OrderedDict[str, list[str]] = OrderedDict( - [(t, []) for t in model_sum_deps.keys()] - ) + neighbors: dict[str, list[str]] = {t: [] for t in model_sum_deps.keys()} for key, deps in model_sum_deps.items(): for dim in deps: if dim in sum_vars: @@ -1038,7 +1036,7 @@ def _partition( components = [] while neighbors: v, pending = neighbors.popitem() - component = OrderedDict([(v, None)]) # used as an OrderedSet + component = {v: None} # used as an OrderedSet for v in pending: component[v] = None while pending: diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index 8a28dd1f0..8253215db 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict, namedtuple +from collections import defaultdict from functools import partial import math import os @@ -94,7 +94,7 @@ def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key): rng_keys = random.split(rng_key, len(mass_matrix_sqrt)) r = {} for (site_names, mm_sqrt), rng_key in zip(mass_matrix_sqrt.items(), rng_keys): - r_block = OrderedDict([(k, prototype_r[k]) for k in site_names]) + r_block = {k: prototype_r[k] for k in site_names} r.update(momentum_generator(r_block, mm_sqrt, rng_key)) return r diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index a21e7329e..3403cade5 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict, namedtuple +from collections import defaultdict import jax from jax import grad, jacfwd, random, value_and_grad, vmap @@ -1204,7 +1204,7 @@ def _euclidean_kinetic_energy_grad(inverse_mass_matrix, r): if isinstance(inverse_mass_matrix, dict): r_grad = {} for site_names, inverse_mm in inverse_mass_matrix.items(): - r_block = OrderedDict([(k, r[k]) for k in site_names]) + r_block = {k: r[k] for k in site_names} r_grad.update(_euclidean_kinetic_energy_grad(inverse_mm, r_block)) return r_grad diff --git a/numpyro/util.py b/numpyro/util.py index ef5929fed..72f7f88a0 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -1,7 +1,6 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict from contextlib import contextmanager from functools import partial import inspect @@ -174,7 +173,7 @@ def identity(x, *args, **kwargs): def cached_by(outer_fn, *keys): # Restrict cache size to prevent ref cycles. max_size = 8 - outer_fn._cache = getattr(outer_fn, "_cache", OrderedDict()) + outer_fn._cache = getattr(outer_fn, "_cache", {}) def _wrapped(fn): fn_cache = outer_fn._cache diff --git a/test/contrib/test_funsor.py b/test/contrib/test_funsor.py index 0d403c59b..b162715ca 100644 --- a/test/contrib/test_funsor.py +++ b/test/contrib/test_funsor.py @@ -1,7 +1,6 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict from functools import partial import numpy as np @@ -153,8 +152,8 @@ def _generate_data(): def test_iteration(): def testing(): for i in markov(range(5)): - v1 = to_data(Tensor(jnp.ones(2), OrderedDict([(str(i), Bint[2])]), "real")) - v2 = to_data(Tensor(jnp.zeros(2), OrderedDict([("a", Bint[2])]), "real")) + v1 = to_data(Tensor(jnp.ones(2), {str(i): Bint[2]}, "real")) + v2 = to_data(Tensor(jnp.zeros(2), {"a": Bint[2]}, "real")) fv1 = to_funsor(v1, Real) fv2 = to_funsor(v2, Real) print(i, v1.shape) # shapes should alternate @@ -174,25 +173,23 @@ def testing(): def test_nesting(): def testing(): with markov(): - v1 = to_data(Tensor(jnp.ones(2), OrderedDict([("1", Bint[2])]), "real")) + v1 = to_data(Tensor(jnp.ones(2), {"1": Bint[2]}, "real")) print(1, v1.shape) # shapes should alternate assert v1.shape == (2,) with markov(): - v2 = to_data(Tensor(jnp.ones(2), OrderedDict([("2", Bint[2])]), "real")) + v2 = to_data(Tensor(jnp.ones(2), {"2": Bint[2]}, "real")) print(2, v2.shape) # shapes should alternate assert v2.shape == (2, 1) with markov(): - v3 = to_data( - Tensor(jnp.ones(2), OrderedDict([("3", Bint[2])]), "real") - ) + v3 = to_data(Tensor(jnp.ones(2), {"3": Bint[2]}, "real")) print(3, v3.shape) # shapes should alternate assert v3.shape == (2,) with markov(): v4 = to_data( - Tensor(jnp.ones(2), OrderedDict([("4", Bint[2])]), "real") + Tensor(jnp.ones(2), {"4": Bint[2]}, "real") ) print(4, v4.shape) # shapes should alternate @@ -206,9 +203,7 @@ def test_staggered(): def testing(): for i in markov(range(12)): if i % 4 == 0: - v2 = to_data( - Tensor(jnp.zeros(2), OrderedDict([("a", Bint[2])]), "real") - ) + v2 = to_data(Tensor(jnp.zeros(2), {"a": Bint[2]}, "real")) fv2 = to_funsor(v2, Real) assert v2.shape == (2,) print("a", v2.shape)