[WIP] Move standardization into approximators and make adapter stateless. #486
Changes from 27 commits
Diff for the `Approximator` base class:

@@ -11,17 +11,17 @@

 class Approximator(BackendApproximator):
-    def build(self, data_shapes: any) -> None:
-        mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
+    def build(self, data_shapes: dict[str, tuple[int]]) -> None:
+        mock_data = keras.tree.map_shape_structure(keras.ops.zeros, data_shapes)
         self.build_from_data(mock_data)

     @classmethod
     def build_adapter(cls, **kwargs) -> Adapter:
         # implemented by each respective architecture
         raise NotImplementedError

-    def build_from_data(self, data: dict[str, any]) -> None:
-        self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
+    def build_from_data(self, adapted_data: dict[str, any]) -> None:
+        self.compute_metrics(**filter_kwargs(adapted_data, self.compute_metrics), stage="training")
         self.built = True

     @classmethod

@@ -61,6 +61,9 @@
             max_queue_size=max_queue_size,
         )

+    def call(self, *args, **kwargs):
+        return self.compute_metrics(*args, **kwargs)
+
     def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
         """
         Trains the approximator on the provided dataset or on-demand data generated from the given simulator.

@@ -132,6 +135,7 @@
             logging.info("Building on a test batch.")
             mock_data = dataset[0]
             mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
-            self.build_from_data(mock_data)
+            mock_data_shapes = keras.tree.map_structure(keras.ops.shape, mock_data)
+            self.build(mock_data_shapes)

         return super().fit(dataset=dataset, **kwargs)
Review comment:
I think it would be nice to have a convenience function that calculates mean and std for a dataset, in the format that would be required here. We could also advertise it in the deprecation warning. What do you think?
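One possible shape for such a helper, assuming the data is already in memory as a dict of arrays; the function name, the returned format, and the axis convention are invented here and would need to match whatever the standardization layers actually expect:

```python
import numpy as np

def compute_moments(data: dict[str, np.ndarray], axis: int = 0) -> dict[str, dict[str, np.ndarray]]:
    """Hypothetical convenience helper: per-variable mean and std over the sample axis."""
    return {
        key: {"mean": np.mean(value, axis=axis), "std": np.std(value, axis=axis)}
        for key, value in data.items()
    }

# e.g. moments = compute_moments({"theta": theta_samples, "x": x_samples})
```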
Reply:
Agreed, but such a function will not be very efficient when the entire data set is not (yet) in memory. I see its use mainly for OfflineDataset.
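For the not-in-memory case, one could instead stream over the batches that an indexable dataset (e.g. OfflineDataset) yields and accumulate running moments via the parallel variance update (Chan et al.). This is a rough sketch, not part of the PR, under the assumption that `dataset[i]` returns a dict of arrays keyed by variable name:

```python
import numpy as np

def streaming_moments(dataset, key: str):
    """Sketch: one pass over dict batches, accumulating count, mean, and M2 per feature."""
    count, mean, m2 = 0, None, None
    for i in range(len(dataset)):
        batch = np.asarray(dataset[i][key])  # assumed shape: (batch_size, ...)
        n = batch.shape[0]
        b_mean = batch.mean(axis=0)
        b_m2 = ((batch - b_mean) ** 2).sum(axis=0)
        if mean is None:
            count, mean, m2 = n, b_mean, b_m2
            continue
        delta = b_mean - mean
        total = count + n
        mean = mean + delta * (n / total)
        m2 = m2 + b_m2 + delta**2 * (count * n / total)
        count = total
    return mean, np.sqrt(m2 / count)
```

The single pass keeps memory proportional to one batch, at the cost of touching the whole dataset once before training starts.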