Skip to content

Commit 9166a7b

Browse files
committed
Refactor links
* rename link_functions to links * separated activation function Ordered from OrderedQuantiles, one for generality, the other for automatic smart anchor selection based on quantile levels * introduce link function for learnable positive semi-definite matrices * link tests
1 parent ecef8ed commit 9166a7b

File tree

10 files changed

+191
-64
lines changed

10 files changed

+191
-64
lines changed

bayesflow/link_functions/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

bayesflow/link_functions/ordered_quantiles.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

bayesflow/links/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .ordered import Ordered
2+
from .ordered_quantiles import OrderedQuantiles
3+
from .positive_semi_definite import PositiveSemiDefinite

bayesflow/links/ordered.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import keras
2+
3+
from bayesflow.utils import keras_kwargs
4+
5+
6+
class Ordered(keras.Layer):
7+
def __init__(self, axis: int, anchor_index: int, **kwargs):
8+
super().__init__(**keras_kwargs(kwargs))
9+
self.axis = axis
10+
self.anchor_index = anchor_index
11+
12+
def build(self, input_shape):
13+
super().build(input_shape)
14+
print("build Ordered()")
15+
16+
self.group_indeces = dict(
17+
below=list(range(0, self.anchor_index)),
18+
above=list(range(self.anchor_index + 1, input_shape[self.axis])),
19+
)
20+
21+
def call(self, inputs):
22+
# Divide in anchor, below and above
23+
below_inputs = keras.ops.take(inputs, self.group_indeces["below"], axis=self.axis)
24+
anchor_input = keras.ops.take(inputs, self.anchor_index, axis=self.axis)
25+
anchor_input = keras.ops.expand_dims(anchor_input, axis=self.axis)
26+
above_inputs = keras.ops.take(inputs, self.group_indeces["above"], axis=self.axis)
27+
28+
# Apply softplus for positivity and cumulate to ensure ordered quantiles
29+
below = keras.activations.softplus(below_inputs)
30+
above = keras.activations.softplus(above_inputs)
31+
32+
below = anchor_input - keras.ops.flip(keras.ops.cumsum(below, axis=self.axis), self.axis)
33+
above = anchor_input + keras.ops.cumsum(above, axis=self.axis)
34+
35+
# Concatenate and reshape back
36+
x = keras.ops.concatenate([below, anchor_input, above], self.axis)
37+
return x
38+
39+
def compute_output_shape(self, input_shape):
40+
return input_shape
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import keras
2+
3+
from bayesflow.utils import keras_kwargs
4+
5+
from collections.abc import Sequence
6+
7+
from .ordered import Ordered
8+
9+
10+
class OrderedQuantiles(Ordered):
11+
def __init__(self, q: Sequence[float] = None, axis: int = None, anchor_index: int = None, **kwargs):
12+
super().__init__(axis, anchor_index, **keras_kwargs(kwargs))
13+
self.q = q
14+
15+
def build(self, input_shape):
16+
if 1 < len(input_shape) <= 3:
17+
self.axis = -2
18+
if self.q is None:
19+
# choose the middle of the specified axis as anchor index
20+
num_quantile_levels = input_shape[self.axis]
21+
self.anchor_index = num_quantile_levels // 2
22+
else:
23+
# choose quantile level closest to median as anchor index
24+
self.anchor_index = keras.ops.argmin(keras.ops.abs(keras.ops.convert_to_tensor(self.q) - 0.5))
25+
26+
super().build(input_shape)
27+
28+
else:
29+
raise AssertionError(
30+
f"Cannot resolve which axis should be ordered automatically from input shape {input_shape}."
31+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import keras
2+
3+
from bayesflow.utils import keras_kwargs
4+
5+
6+
class PositiveSemiDefinite(keras.Layer):
7+
def __init__(self, **kwargs):
8+
super().__init__(**keras_kwargs(kwargs))
9+
10+
def call(self, inputs):
11+
return keras.ops.einsum("...ij,...kj->...ik", inputs, inputs)
12+
13+
def compute_output_shape(self, input_shape):
14+
return input_shape

tests/test_links/__init__.py

Whitespace-only changes.

tests/test_links/conftest.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import numpy as np
2+
import keras
3+
import pytest
4+
5+
6+
@pytest.fixture()
7+
def batch_size():
8+
return 16
9+
10+
11+
@pytest.fixture()
12+
def num_quantiles():
13+
return 19
14+
15+
16+
@pytest.fixture()
17+
def quantiles_np(num_quantiles):
18+
return np.linspace(0, 1, num_quantiles + 2)[1:-1]
19+
20+
21+
@pytest.fixture()
22+
def quantiles_py(quantiles_np):
23+
return list(quantiles_np)
24+
25+
26+
@pytest.fixture()
27+
def quantiles_keras(quantiles_np):
28+
return keras.ops.convert_to_tensor(quantiles_np)
29+
30+
31+
@pytest.fixture()
32+
def none():
33+
return None
34+
35+
36+
@pytest.fixture(params=["quantiles_np", "quantiles_py", "quantiles_keras", "none"], scope="function")
37+
def quantiles(request):
38+
return request.getfixturevalue(request.param)
39+
40+
41+
@pytest.fixture()
42+
def num_variables():
43+
return 10
44+
45+
46+
@pytest.fixture()
47+
def unordered(batch_size, num_quantiles, num_variables):
48+
return keras.random.normal((batch_size, num_quantiles, num_variables))
49+
50+
51+
@pytest.fixture()
52+
def random_matrix_batch(batch_size, num_variables):
53+
return keras.random.normal((batch_size, num_variables, num_variables))

tests/test_links/test_links.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import numpy as np
2+
import pytest
3+
4+
5+
def check_ordering(output, axis):
6+
assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}."
7+
for i in range(output.ndim):
8+
if i != axis % output.ndim:
9+
assert not np.all(
10+
np.diff(output, axis=i) > 0
11+
), f"is ordered along axis which is not meant to be ordered: {i}."
12+
13+
14+
@pytest.mark.parametrize("axis", [0, 1, 2])
15+
def test_ordering(axis, unordered):
16+
from bayesflow.links import Ordered
17+
18+
activation = Ordered(axis=axis, anchor_index=5)
19+
20+
output = activation(unordered)
21+
22+
check_ordering(output, axis)
23+
24+
25+
def test_quantile_ordering(quantiles, unordered):
26+
from bayesflow.links import OrderedQuantiles
27+
28+
activation = OrderedQuantiles(q=None)
29+
30+
activation.build(unordered.shape)
31+
axis = activation.axis
32+
33+
output = activation(unordered)
34+
35+
check_ordering(output, axis)
36+
37+
38+
def test_positive_semi_definite(random_matrix_batch):
39+
from bayesflow.links import PositiveSemiDefinite
40+
41+
activation = PositiveSemiDefinite()
42+
43+
output = activation(random_matrix_batch)
44+
45+
eigenvalues = np.linalg.eig(output).eigenvalues
46+
47+
assert np.all(eigenvalues.real > 0) and np.all(
48+
np.isclose(eigenvalues.imag, 0)
49+
), "output is not positive semi-definite."

tests/test_scores/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def batch_size():
99

1010
@pytest.fixture()
1111
def num_variables():
12-
return 4
12+
return 10
1313

1414

1515
@pytest.fixture()

0 commit comments

Comments
 (0)