Commit 62f9ab1

Merge pull request #111 from Genentech/transform-ensembles
Transform ensembles
2 parents: f23dd76 + 39368ff

File tree: 3 files changed (+58, -13 lines)


src/grelu/lightning/__init__.py

Lines changed: 45 additions & 6 deletions
```diff
@@ -1146,6 +1146,12 @@ def __init__(self, models: list, model_names: Optional[List[str]] = None) -> None:
             "n_tasks": sum([model.model_params["n_tasks"] for model in self.models])
         }
         self.data_params = {"tasks": defaultdict(list)}
+
+        # Set models to eval mode (since this class is used for prediction and design)
+        for model in self.models:
+            model.eval()
+
+        self.reset_transform()
         self._combine_tasks()
 
     def _combine_tasks(self) -> None:
@@ -1171,18 +1177,51 @@ def forward(self, x: Tensor) -> Tensor:
         """
         Forward Pass.
         """
-        return torch.cat([model(x) for model in self.models], axis=1)  # B, T, L
+        x = torch.cat([model(x) for model in self.models], axis=1)  # B, T, L
+
+        # apply transform to ensemble output
+        x = self.transform(x)
+        return x
 
-    def predict_on_dataset(self, dataset: Callable, **kwargs) -> np.ndarray:
+    def add_transform(self, prediction_transform: Callable) -> None:
         """
-        This will return the concatenated predictions from all the
-        constituent models, in the order in which they were supplied.
-        Predictions will be concatenated along the task axis.
+        Add a prediction transform
+        """
+        if prediction_transform is not None:
+            self.transform = prediction_transform
+
+    def reset_transform(self) -> None:
+        """
+        Remove a prediction transform
         """
-        return np.concatenate(
+        self.transform = nn.Identity()
+
+    def predict_on_dataset(
+        self,
+        dataset: Callable,
+        **kwargs,
+    ):
+        """
+        Predict for a dataset of sequences or variants. This will return
+        the concatenated predictions from all the constituent models, in the
+        order in which they were supplied to __.init__. Predictions will be
+        concatenated along the task axis.
+
+        Args:
+            dataset: Dataset object that yields one-hot encoded sequences
+            **kwargs: Additional arguments to pass to the `predict_on_dataset`
+                functions of the constituent models.
+
+        Returns:
+            Model predictions as a numpy array
+        """
+        preds = np.concatenate(
             [model.predict_on_dataset(dataset, **kwargs) for model in self.models],
             axis=-2,
         )
+        if not isinstance(self.transform, nn.Identity):
+            preds = self.transform(torch.tensor(preds)).numpy()
+        return preds
 
     def get_task_idxs(
         self, tasks: Union[str, int, List[str], List[int]], key: str = "name"
```
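The changes above add an optional prediction transform to the ensemble wrapper: it defaults to `nn.Identity()`, is attached with `add_transform()` and cleared with `reset_transform()`, and is applied both in `forward()` and to the output of `predict_on_dataset()`. Below is a minimal, self-contained sketch of that pattern; `ToyEnsemble`, `MeanOverTasks`, and the toy models are illustrative stand-ins, not the gReLU classes themselves.

```python
import torch
from torch import nn


class ToyEnsemble(nn.Module):
    """Illustrative stand-in for the ensemble wrapper changed above."""

    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)
        # Set models to eval mode (the wrapper is used for prediction and design)
        for model in self.models:
            model.eval()
        # Start with an identity transform
        self.reset_transform()

    def add_transform(self, prediction_transform):
        if prediction_transform is not None:
            self.transform = prediction_transform

    def reset_transform(self):
        self.transform = nn.Identity()

    def forward(self, x):
        # Concatenate constituent model outputs along the task axis (B, T, L),
        # then apply the transform to the ensemble output
        x = torch.cat([model(x) for model in self.models], dim=1)
        return self.transform(x)


class MeanOverTasks(nn.Module):
    """Toy transform: average over the task axis, like Aggregate(task_aggfunc="mean")."""

    def forward(self, x):
        return x.mean(dim=-2, keepdim=True)


# Two toy "models", each mapping (B, 8) inputs to (B, 2, 1) outputs
models = [nn.Sequential(nn.Linear(8, 2), nn.Unflatten(1, (2, 1))) for _ in range(2)]
ens = ToyEnsemble(models)
x = torch.randn(2, 8)
print(ens(x).shape)                 # torch.Size([2, 4, 1]): tasks concatenated
ens.add_transform(MeanOverTasks())
print(ens(x).shape)                 # torch.Size([2, 1, 1]): transform applied
ens.reset_transform()
print(ens(x).shape)                 # torch.Size([2, 4, 1]) again
```

Defaulting to `nn.Identity()` lets `forward()` call `self.transform(x)` unconditionally, while `predict_on_dataset()` skips the transform step entirely when nothing has been added.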

src/grelu/transforms/prediction_transforms.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -97,14 +97,14 @@ def filter(self, x: Union[Tensor, np.ndarray]) -> Union[Tensor, np.ndarray]:
         """
         # Select positions
         if self.positions is not None:
-            x = x[:, :, self.positions]
+            x = x[..., self.positions]
 
         # Select tasks
         if self.tasks is not None:
-            x = x[:, self.tasks, :]
+            x = x[..., self.tasks, :]
         elif self.except_tasks is not None:
             keep = [i for i in range(x.shape[1]) if i not in self.except_tasks]
-            x = x[:, keep, :]
+            x = x[..., keep, :]
         return x
 
     def torch_aggregate(self, x: Tensor) -> Tensor:
@@ -113,11 +113,11 @@ def torch_aggregate(self, x: Tensor) -> Tensor:
         """
         # Aggregate positions
         if self.length_aggfunc is not None:
-            x = self.length_aggfunc(x, axis=2, keepdims=True)
+            x = self.length_aggfunc(x, axis=-1, keepdims=True)
 
         # Aggregate tasks
         if self.task_aggfunc is not None:
-            x = self.task_aggfunc(x, axis=1, keepdims=True)
+            x = self.task_aggfunc(x, axis=-2, keepdims=True)
         return x
 
     def numpy_aggregate(self, x: np.ndarray) -> np.ndarray:
@@ -126,11 +126,11 @@ def numpy_aggregate(self, x: np.ndarray) -> np.ndarray:
         """
         # Aggregate positions
         if self.length_aggfunc is not None:
-            x = self.length_aggfunc_numpy(x, axis=2, keepdims=True)
+            x = self.length_aggfunc_numpy(x, axis=-1, keepdims=True)
 
         # Aggregate tasks
         if self.task_aggfunc is not None:
-            x = self.task_aggfunc_numpy(x, axis=1, keepdims=True)
+            x = self.task_aggfunc_numpy(x, axis=-2, keepdims=True)
 
         return x
```
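Replacing explicit leading indices (`x[:, :, self.positions]`, `axis=2`) with `Ellipsis` and negative axes makes the filtering and aggregation act on the last two dimensions (tasks, length), so the same transform works whether or not a leading batch dimension is present. A small numpy sketch of the idea, illustrative only and not the `Aggregate` class itself:

```python
import numpy as np


def aggregate(x: np.ndarray) -> np.ndarray:
    # Aggregate positions along the last axis (length), then tasks along the
    # second-to-last axis, keeping both dimensions
    x = np.mean(x, axis=-1, keepdims=True)
    x = np.mean(x, axis=-2, keepdims=True)
    return x


batched = np.random.rand(2, 4, 10)   # (batch, tasks, length)
unbatched = np.random.rand(4, 10)    # (tasks, length)
print(aggregate(batched).shape)      # (2, 1, 1)
print(aggregate(unbatched).shape)    # (1, 1)
```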

tests/test_lightning.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -371,6 +371,12 @@ def test_lightning_model_ensemble():
     preds = model.predict_on_dataset(dataset=udataset, devices="cpu")
     assert preds.shape == (2, 4, 1)
 
+    # Test transform
+    t = Aggregate(task_aggfunc="mean")
+    model.add_transform(t)
+    preds = model.predict_on_dataset(dataset=udataset, devices="cpu")
+    assert preds.shape == (2, 1, 1)
+
 
     bin_model = generate_model(task="binary", loss="bce", n_tasks=2)
     bin_model.model_params["crop_len"] = 0
```
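The new test covers the full path: adding `Aggregate(task_aggfunc="mean")` to the ensemble collapses the task axis of the concatenated predictions from 4 to 1. Inside `predict_on_dataset`, a non-identity transform is applied to the numpy predictions via a round trip through a tensor; the sketch below mirrors that guard, with `MeanOverTasks` standing in for a real prediction transform.

```python
import numpy as np
import torch
from torch import nn


class MeanOverTasks(nn.Module):
    # Stand-in for a prediction transform such as Aggregate(task_aggfunc="mean")
    def forward(self, x):
        return x.mean(dim=-2, keepdim=True)


def apply_transform(preds: np.ndarray, transform: nn.Module) -> np.ndarray:
    # Mirror of the check in predict_on_dataset: only transform when something
    # other than nn.Identity has been added, hopping through a tensor because
    # transforms operate on tensors
    if not isinstance(transform, nn.Identity):
        preds = transform(torch.tensor(preds)).numpy()
    return preds


preds = np.random.rand(2, 4, 1)                        # concatenated ensemble predictions
print(apply_transform(preds, nn.Identity()).shape)     # (2, 4, 1): unchanged
print(apply_transform(preds, MeanOverTasks()).shape)   # (2, 1, 1): matches the new assert
```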
