Skip to content

Commit 608a6f4

Browse files
committed
add and adapt type hints
1 parent 2f78f65 commit 608a6f4

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

bayesflow/adapters/transforms/convert_dtype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def get_config(self) -> dict:
3232
}
3333
return serialize(config)
3434

35-
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
35+
def forward(self, data: np.ndarray | dict, **kwargs) -> np.ndarray | dict:
3636
return map_structure(lambda d: d.astype(self.to_dtype, copy=False), data)
3737

38-
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
38+
def inverse(self, data: np.ndarray | dict, **kwargs) -> np.ndarray | dict:
3939
return map_structure(lambda d: d.astype(self.from_dtype, copy=False), data)

bayesflow/utils/tree.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import optree
2+
from typing import Callable
23

34

45
def flatten_shape(structure):
@@ -14,7 +15,7 @@ def is_shape_tuple(x):
1415
return leaves
1516

1617

17-
def map_dict(func, dictionary):
18+
def map_dict(func: Callable, dictionary: dict) -> dict:
1819
"""Applies a function to all leaves of a (possibly nested) dictionary.
1920
2021
Parameters

0 commit comments

Comments
 (0)