File tree Expand file tree Collapse file tree 2 files changed +4
-3
lines changed Expand file tree Collapse file tree 2 files changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -32,8 +32,8 @@ def get_config(self) -> dict:
3232 }
3333 return serialize (config )
3434
35- def forward (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
35+ def forward (self , data : np .ndarray | dict , ** kwargs ) -> np .ndarray | dict :
3636 return map_structure (lambda d : d .astype (self .to_dtype , copy = False ), data )
3737
38- def inverse (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
38+ def inverse (self , data : np .ndarray | dict , ** kwargs ) -> np .ndarray | dict :
3939 return map_structure (lambda d : d .astype (self .from_dtype , copy = False ), data )
Original file line number Diff line number Diff line change 11import optree
2+ from typing import Callable
23
34
45def flatten_shape (structure ):
@@ -14,7 +15,7 @@ def is_shape_tuple(x):
1415 return leaves
1516
1617
17- def map_dict (func , dictionary ) :
18+ def map_dict (func : Callable , dictionary : dict ) -> dict :
1819 """Applies a function to all leaves of a (possibly nested) dictionary.
1920
2021 Parameters
You can’t perform that action at this time.
0 commit comments