Skip to content

Commit a80011c

Browse files
committed
Move part of nested dictionary operations to dict_utils
1 parent 23c1dd8 commit a80011c

File tree

4 files changed

+17
-16
lines changed

4 files changed

+17
-16
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from bayesflow.adapters import Adapter
1212
from bayesflow.networks import InferenceNetwork, SummaryNetwork
1313
from bayesflow.types import Tensor
14-
from bayesflow.utils import filter_kwargs, logging, split_arrays
14+
from bayesflow.utils import filter_kwargs, logging, split_arrays, squeeze_inner_estimates_dict
1515
from .approximator import Approximator
1616

1717

@@ -232,16 +232,10 @@ def estimate(
232232
for variable_name in samples.keys()
233233
}
234234

235-
def squeeze_dict(d):
236-
if len(d.keys()) == 1 and "value" in d.keys():
237-
return d["value"]
238-
else:
239-
return d
240-
241235
# remove unnecessary nesting
242236
estimates = {
243237
variable_name: {
244-
outer_key: squeeze_dict(estimates[variable_name][outer_key])
238+
outer_key: squeeze_inner_estimates_dict(estimates[variable_name][outer_key])
245239
for outer_key in estimates[variable_name].keys()
246240
}
247241
for variable_name in estimates.keys()

bayesflow/approximators/point_approximator.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
)
66

77
from bayesflow.types import Tensor
8-
from bayesflow.utils import filter_kwargs, split_arrays
8+
from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict
99
from .continuous_approximator import ContinuousApproximator
1010

1111

@@ -62,16 +62,10 @@ def estimate(
6262
for variable_name in inference_variable_names
6363
}
6464

65-
def squeeze_dict(d):
66-
if len(d.keys()) == 1 and "value" in d.keys():
67-
return d["value"]
68-
else:
69-
return d
70-
7165
# remove unnecessary nesting
7266
conditions = {
7367
variable_name: {
74-
outer_key: squeeze_dict(conditions[variable_name][outer_key])
68+
outer_key: squeeze_inner_estimates_dict(conditions[variable_name][outer_key])
7569
for outer_key in conditions[variable_name].keys()
7670
}
7771
for variable_name in conditions.keys()

bayesflow/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
keras_kwargs,
1313
split_tensors,
1414
split_arrays,
15+
squeeze_inner_estimates_dict,
1516
)
1617
from .dispatch import find_distribution, find_network, find_permutation, find_pooling, find_recurrent_net
1718
from .ecdf import simultaneous_ecdf_bands, ranks

bayesflow/utils/dict_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,3 +318,15 @@ def dicts_to_arrays(
318318
targets=targets,
319319
priors=priors,
320320
)
321+
322+
323+
def squeeze_inner_estimates_dict(estimates):
324+
"""If a dictionary has only one key-value pair and the key is "value", return only its value.
325+
Otherwise, return the unchanged dictionary.
326+
327+
This method helps to remove unnecessary nesting levels.
328+
"""
329+
if len(estimates.keys()) == 1 and "value" in estimates.keys():
330+
return estimates["value"]
331+
else:
332+
return estimates

0 commit comments

Comments
 (0)