-
Notifications
You must be signed in to change notification settings - Fork 78
Subset arrays #411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Subset arrays #411
Changes from 9 commits
69e236d
9c0da4c
d57aee4
8d834da
6c1d503
2e83846
dee4534
71dc35a
6c34a5d
c3640cb
f17322f
5312c5f
5361c04
504344b
4218b70
7e3911b
350513f
415b658
37598b0
ee28392
5bbf44a
676c19f
f261b50
87017e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,3 +39,6 @@ docs/ | |
|
|
||
| # MacOS | ||
| .DS_Store | ||
|
|
||
| # Rproj | ||
| .Rproj.user | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,8 @@ | |
| Standardize, | ||
| ToArray, | ||
| Transform, | ||
| SubsampleArray, | ||
| Take | ||
| ) | ||
| from .transforms.filter_transform import Predicate | ||
|
|
||
|
|
@@ -541,6 +543,42 @@ def one_hot(self, keys: str | Sequence[str], num_classes: int): | |
| transform = MapTransform({key: OneHot(num_classes=num_classes) for key in keys}) | ||
| self.transforms.append(transform) | ||
| return self | ||
|
|
||
| def random_subsample(self, | ||
| keys: str | Sequence[str], | ||
| *, | ||
| sample_size: int, | ||
| axis: int=-1, | ||
| **kwargs, | ||
| ): | ||
| """ | ||
| Append a :py:class:`~transforms.SubsampleArray` transform to the adapter. | ||
eodole marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| Parameters | ||
| ---------- | ||
| predicate : Predicate, optional | ||
| Function that indicates which variables should be transformed. | ||
| include : str or Sequence of str, optional | ||
| Names of variables to include in the transform. | ||
| exclude : str or Sequence of str, optional | ||
| Names of variables to exclude from the transform. | ||
| **kwargs : dict | ||
| Additional keyword arguments passed to the transform. | ||
|
|
||
| """ | ||
| if isinstance(keys, str): | ||
| keys = [keys] | ||
|
|
||
| transform = MapTransform( | ||
| transform_map={ | ||
| key: SubsampleArray(sample_size=sample_size, axis=axis) | ||
| for key in keys | ||
| } | ||
|
|
||
| ) | ||
|
|
||
| self.transforms.append(transform) | ||
| return self | ||
|
|
||
| def rename(self, from_key: str, to_key: str): | ||
| """Append a :py:class:`~transforms.Rename` transform to the adapter. | ||
|
|
@@ -601,6 +639,36 @@ def standardize( | |
| self.transforms.append(transform) | ||
| return self | ||
|
|
||
| def take(self, | ||
| *, | ||
| predicate: Predicate = None, | ||
| include: str | Sequence[str] = None, | ||
| exclude: str | Sequence[str] = None, | ||
| **kwargs,): | ||
| """ | ||
| Append a :py:class:`~transforms.Take` transform to the adapter. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| predicate : Predicate, optional | ||
| Function that indicates which variables should be transformed. | ||
| include : str or Sequence of str, optional | ||
| Names of variables to include in the transform. | ||
| exclude : str or Sequence of str, optional | ||
| Names of variables to exclude from the transform. | ||
| **kwargs : dict | ||
| Additional keyword arguments passed to the transform. """ | ||
| transform = FilterTransform( | ||
| transform_constructor=Take, | ||
| predicate=predicate, | ||
| include=include, | ||
| exclude=exclude, | ||
| **kwargs, | ||
| ) | ||
| self.transforms.append(transform) | ||
| return self | ||
|
|
||
|
|
||
| def to_array( | ||
| self, | ||
| *, | ||
|
|
@@ -631,3 +699,9 @@ def to_array( | |
| ) | ||
| self.transforms.append(transform) | ||
| return self | ||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These empty lines would be removed by the formatter, which should automatically run on pre-commit.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i made a new environment based on the contribution.md
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Try running |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| import numpy as np | ||
| from keras.saving import register_keras_serializable as serializable | ||
|
|
||
| from .elementwise_transform import ElementwiseTransform | ||
|
|
||
|
|
||
| @serializable(package="bayesflow.adapters") | ||
| class SubsampleArray(ElementwiseTransform): | ||
| """ | ||
| A transform that takes a random subsample of the data within an axis. | ||
|
|
||
| Example: adapter.subsample("x", sample_size = 3, axis = -1) | ||
|
|
||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| sample_size: int, | ||
| axis: int = -1, | ||
| ): | ||
| super().__init__() | ||
| self.sample_size = sample_size | ||
| self.axis = axis | ||
|
|
||
| def forward(self, data: np.ndarray): | ||
| sample_size = self.sample_size | ||
| axis = self.axis | ||
|
|
||
| max_sample_size = data.shape[axis] | ||
|
|
||
| sample_indices = np.random.permutation(max_sample_size)[ | ||
| 0 : sample_size - 1 | ||
| ] # random sample without replacement | ||
LarsKue marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| return np.take(data, sample_indices, axis) | ||
|
|
||
| def inverse(self, data, **kwargs): | ||
| # non invertible transform | ||
| return data | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| import numpy as np | ||
| from keras.saving import register_keras_serializable as serializable | ||
|
|
||
| from .elementwise_transform import ElementwiseTransform | ||
|
|
||
|
|
||
| @serializable(package="bayesflow.adapters") | ||
| class Take(ElementwiseTransform): | ||
| """ | ||
| A transform to reduce the dimensionality of arrays output by the summary network | ||
| Axis is a mandatory argument and will default to the last axis. | ||
| Example: adapter.take("x", np.arange(0,3), axis=-1) | ||
|
|
||
| """ | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
|
|
||
| def forward(self, data, indices, axis=-1): | ||
eodole marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return np.take(data, indices, axis) | ||
|
|
||
| def inverse(self, data): | ||
| # not a true invertible function | ||
| return data | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,9 +30,9 @@ def check_ordering(output, axis): | |
| assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}." | ||
| for i in range(output.ndim): | ||
| if i != axis % output.ndim: | ||
| assert not np.all(np.diff(output, axis=i) > 0), ( | ||
| f"is ordered along axis which is not meant to be ordered: {i}." | ||
| ) | ||
| assert not np.all( | ||
| np.diff(output, axis=i) > 0 | ||
| ), f"is ordered along axis which is not meant to be ordered: {i}." | ||
|
||
|
|
||
|
|
||
| @pytest.mark.parametrize("axis", [0, 1, 2]) | ||
|
|
@@ -69,6 +69,6 @@ def test_positive_semi_definite(random_matrix_batch): | |
| output = keras.ops.convert_to_numpy(output) | ||
| eigenvalues = np.linalg.eig(output).eigenvalues | ||
|
|
||
| assert np.all(eigenvalues.real > 0) and np.all(np.isclose(eigenvalues.imag, 0)), ( | ||
| f"output is not positive semi-definite: real={eigenvalues.real}, imag={eigenvalues.imag}" | ||
| ) | ||
| assert np.all(eigenvalues.real > 0) and np.all( | ||
| np.isclose(eigenvalues.imag, 0) | ||
| ), f"output is not positive semi-definite: real={eigenvalues.real}, imag={eigenvalues.imag}" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am unfamiliar with R. What is this directory used for, and should all other users have it ignored too? Otherwise, please put this in your local
.git/info/excludeinstead.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
according to @stefanradev93, this should be
.Rproj