Skip to content

Commit 3e7cea5

Browse files
eodoleLarsKue
andauthored
Made Adapters Sliceable (#285)
* in test phase, print statements included but get item implemented, need to define set item, might be nice to list indecies of transforms in the default print * finished slice implementation for set item, noticed also that element transform was not added to the list for things to imported in the adapter file so added that * finished adapter slicing, i also made a test file but its not in the commits * ran linter * removed print statements * indexing added to print statements for adapter * made modifications in line with Lars comments and they passed the tests that I had previously written * removed unecessary comments2 * add basic list methods to Adapter --------- Co-authored-by: larskue <[email protected]>
1 parent 77dee84 commit 3e7cea5

File tree

1 file changed

+64
-7
lines changed

1 file changed

+64
-7
lines changed

bayesflow/adapters/adapter.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Callable, Sequence
1+
from collections.abc import Callable, MutableSequence, Sequence
22

33
import numpy as np
44
from keras.saving import (
@@ -26,17 +26,16 @@
2626
ToArray,
2727
Transform,
2828
)
29-
3029
from .transforms.filter_transform import Predicate
3130

3231

3332
@serializable(package="bayesflow.adapters")
34-
class Adapter:
33+
class Adapter(MutableSequence[Transform]):
3534
def __init__(self, transforms: Sequence[Transform] | None = None):
3635
if transforms is None:
3736
transforms = []
3837

39-
self.transforms = transforms
38+
self.transforms = list(transforms)
4039

4140
@staticmethod
4241
def create_default(inference_variables: Sequence[str]) -> "Adapter":
@@ -77,12 +76,70 @@ def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) ->
7776
return self.forward(data, **kwargs)
7877

7978
def __repr__(self):
80-
return f"Adapter([{' -> '.join(map(repr, self.transforms))}])"
79+
result = ""
80+
for i, transform in enumerate(self):
81+
result += f"{i}: {transform!r}"
82+
if i != len(self) - 1:
83+
result += " -> "
8184

82-
def add_transform(self, transform: Transform):
83-
self.transforms.append(transform)
85+
return f"Adapter([{result}])"
86+
87+
# list methods
88+
89+
def append(self, value: Transform) -> "Adapter":
90+
self.transforms.append(value)
8491
return self
8592

93+
def __delitem__(self, key: int | slice):
94+
del self.transforms[key]
95+
96+
def extend(self, values: Sequence[Transform]) -> "Adapter":
97+
if isinstance(values, Adapter):
98+
values = values.transforms
99+
100+
self.transforms.extend(values)
101+
102+
return self
103+
104+
def __getitem__(self, item: int | slice) -> "Adapter":
105+
if isinstance(item, int):
106+
return self.transforms[item]
107+
108+
return Adapter(self.transforms[item])
109+
110+
def insert(self, index: int, value: Transform | Sequence[Transform]) -> "Adapter":
111+
if isinstance(value, Adapter):
112+
value = value.transforms
113+
114+
if isinstance(value, Sequence):
115+
# convenience: Adapters are always flat
116+
self.transforms = self.transforms[:index] + list(value) + self.transforms[index:]
117+
else:
118+
self.transforms.insert(index, value)
119+
120+
return self
121+
122+
def __setitem__(self, key: int | slice, value: Transform | Sequence[Transform]) -> "Adapter":
123+
if isinstance(value, Adapter):
124+
value = value.transforms
125+
126+
if isinstance(key, int) and isinstance(value, Sequence):
127+
if key < 0:
128+
key += len(self.transforms)
129+
130+
key = slice(key, key + 1)
131+
132+
self.transforms[key] = value
133+
134+
return self
135+
136+
def __len__(self):
137+
return len(self.transforms)
138+
139+
# adapter methods
140+
141+
add_transform = append
142+
86143
def apply(
87144
self,
88145
*,

0 commit comments

Comments
 (0)