Skip to content

Commit 2c30964

Browse files
authored
Merge pull request #150 from Genentech/update-ledidi
fixed ledidi code to work with latest version and reran design tutorial
2 parents 93c7ae2 + 84d2a89 commit 2c30964

File tree

4 files changed

+333
-244
lines changed

4 files changed

+333
-244
lines changed

docs/tutorials/4_design.ipynb

Lines changed: 287 additions & 222 deletions
Large diffs are not rendered by default.

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ install_requires =
6969
captum == 0.5.0
7070
logomaker >= 0.8
7171
pyBigWig
72-
ledidi < 2.0.0
72+
ledidi
7373
tangermeme >= 0.5.0
7474
memelite
7575
pygenomeviz <= 0.4.4

src/grelu/design.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,6 @@ def ledidi(
247247
"""
248248
from ledidi import Ledidi
249249

250-
# Add the prediction transform
251-
model.add_transform(prediction_transform)
252-
253250
def loss_func(x, target):
254251
return -Tensor(x).mean()
255252

@@ -264,30 +261,40 @@ def loss_func(x, target):
264261
else:
265262
input_mask = None
266263

264+
# Add the prediction transform
265+
model.add_transform(prediction_transform)
266+
267267
# Move model to device
268268
orig_device = model.device
269269
model = model.to(torch.device(devices))
270+
model.eval()
271+
272+
try:
273+
print("Running Ledidi")
274+
275+
# Initialize ledidi
276+
designer = Ledidi(
277+
model,
278+
X[0].shape,
279+
output_loss=loss_func,
280+
max_iter=max_iter,
281+
input_mask=input_mask,
282+
target=None,
283+
**kwargs,
284+
)
285+
designer = designer.to(torch.device(devices))
270286

271-
# Initialize ledidi
272-
designer = Ledidi(
273-
model,
274-
X[0].shape,
275-
output_loss=loss_func,
276-
max_iter=max_iter,
277-
input_mask=input_mask,
278-
target=None,
279-
**kwargs,
280-
)
281-
designer = designer.to(torch.device(devices))
287+
# Run ledidi
288+
X_hat = designer.fit_transform(X, torch.tensor(0)).cpu()
282289

283-
# Run ledidi
284-
X_hat = designer.fit_transform(X, None).cpu()
290+
finally:
291+
print("Cleaning up model state...")
285292

286-
# Transfer device
287-
model = model.to(orig_device)
293+
# Transfer device
294+
model = model.to(orig_device)
288295

289-
# Remove the transform
290-
model.reset_transform()
296+
# Remove the transform
297+
model.reset_transform()
291298

292299
# Return sequences as strings
293300
return convert_input_type(X_hat, "strings")

tests/test_design.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import pandas as pd
33
from torch import Tensor, nn
44

5-
from grelu.design import evolve
5+
from grelu.design import evolve, ledidi
66
from grelu.lightning import LightningModel
77
from grelu.transforms.prediction_transforms import Aggregate, Specificity
88
from grelu.transforms.seq_transforms import PatternScore
9+
from grelu.sequence.format import check_string_dna
910

1011
model = LightningModel(
1112
model_params={
@@ -220,3 +221,19 @@ def test_evolve_7():
220221
assert np.all(output["iter"] == [0, 1, 2])
221222
assert np.all(output.seq == ['AT', 'AC', 'CC'])
222223
assert np.all(output.label1 == [0.5, 1.5, 2.])
224+
225+
226+
def test_ledidi():
227+
output = ledidi(
228+
seq='GGTATTCATT',
229+
model= model,
230+
prediction_transform = None,
231+
max_iter = 20000,
232+
positions = None,
233+
devices = "cpu",
234+
num_workers = 1,
235+
lr=0.01,
236+
early_stopping_iter=1000
237+
)
238+
assert isinstance(output, list)
239+
assert check_string_dna(output)

0 commit comments

Comments
 (0)