|
1 | 1 | from collections.abc import Sequence |
2 | 2 | import numpy as np |
3 | 3 |
|
4 | | -from keras.saving import ( |
5 | | - deserialize_keras_object as deserialize, |
6 | | - register_keras_serializable as serializable, |
7 | | - serialize_keras_object as serialize, |
8 | | -) |
| 4 | +from bayesflow.utils.serialization import serialize, serializable |
9 | 5 |
|
10 | 6 | from .transform import Transform |
11 | 7 |
|
12 | 8 |
|
13 | | -@serializable(package="bayesflow.adapters") |
| 9 | +@serializable |
14 | 10 | class Broadcast(Transform): |
15 | 11 | """ |
16 | 12 | Broadcasts arrays or scalars to the shape of a given other array. |
@@ -96,31 +92,15 @@ def __init__( |
96 | 92 | self.exclude = exclude |
97 | 93 | self.squeeze = squeeze |
98 | 94 |
|
99 | | - @classmethod |
100 | | - def from_config(cls, config: dict, custom_objects=None) -> "Broadcast": |
101 | | - # Deserialize turns tuples to lists, undo it if necessary |
102 | | - exclude = deserialize(config["exclude"], custom_objects) |
103 | | - exclude = tuple(exclude) if isinstance(exclude, list) else exclude |
104 | | - expand = deserialize(config["expand"], custom_objects) |
105 | | - expand = tuple(expand) if isinstance(expand, list) else expand |
106 | | - squeeze = deserialize(config["squeeze"], custom_objects) |
107 | | - squeeze = tuple(squeeze) if isinstance(squeeze, list) else squeeze |
108 | | - return cls( |
109 | | - keys=deserialize(config["keys"], custom_objects), |
110 | | - to=deserialize(config["to"], custom_objects), |
111 | | - expand=expand, |
112 | | - exclude=exclude, |
113 | | - squeeze=squeeze, |
114 | | - ) |
115 | | - |
116 | 95 | def get_config(self) -> dict: |
117 | | - return { |
118 | | - "keys": serialize(self.keys), |
119 | | - "to": serialize(self.to), |
120 | | - "expand": serialize(self.expand), |
121 | | - "exclude": serialize(self.exclude), |
122 | | - "squeeze": serialize(self.squeeze), |
| 96 | + config = { |
| 97 | + "keys": self.keys, |
| 98 | + "to": self.to, |
| 99 | + "expand": self.expand, |
| 100 | + "exclude": self.exclude, |
| 101 | + "squeeze": self.squeeze, |
123 | 102 | } |
| 103 | + return serialize(config) |
124 | 104 |
|
125 | 105 | # noinspection PyMethodOverriding |
126 | 106 | def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]: |
|
0 commit comments