File tree Expand file tree Collapse file tree 4 files changed +333
-244
lines changed
Expand file tree Collapse file tree 4 files changed +333
-244
lines changed Load Diff Large diffs are not rendered by default.
Original file line number Diff line number Diff line change @@ -69,7 +69,7 @@ install_requires =
6969 captum == 0.5.0
7070 logomaker >= 0.8
7171 pyBigWig
72- ledidi < 2.0.0
72+ ledidi
7373 tangermeme >= 0.5.0
7474 memelite
7575 pygenomeviz <= 0.4.4
Original file line number Diff line number Diff line change @@ -247,9 +247,6 @@ def ledidi(
247247 """
248248 from ledidi import Ledidi
249249
250- # Add the prediction transform
251- model .add_transform (prediction_transform )
252-
253250 def loss_func (x , target ):
254251 return - Tensor (x ).mean ()
255252
@@ -264,30 +261,40 @@ def loss_func(x, target):
264261 else :
265262 input_mask = None
266263
264+ # Add the prediction transform
265+ model .add_transform (prediction_transform )
266+
267267 # Move model to device
268268 orig_device = model .device
269269 model = model .to (torch .device (devices ))
270+ model .eval ()
271+
272+ try :
273+ print ("Running Ledidi" )
274+
275+ # Initialize ledidi
276+ designer = Ledidi (
277+ model ,
278+ X [0 ].shape ,
279+ output_loss = loss_func ,
280+ max_iter = max_iter ,
281+ input_mask = input_mask ,
282+ target = None ,
283+ ** kwargs ,
284+ )
285+ designer = designer .to (torch .device (devices ))
270286
271- # Initialize ledidi
272- designer = Ledidi (
273- model ,
274- X [0 ].shape ,
275- output_loss = loss_func ,
276- max_iter = max_iter ,
277- input_mask = input_mask ,
278- target = None ,
279- ** kwargs ,
280- )
281- designer = designer .to (torch .device (devices ))
287+ # Run ledidi
288+ X_hat = designer .fit_transform (X , torch .tensor (0 )).cpu ()
282289
283- # Run ledidi
284- X_hat = designer . fit_transform ( X , None ). cpu ( )
290+ finally :
291+ print ( "Cleaning up model state..." )
285292
286- # Transfer device
287- model = model .to (orig_device )
293+ # Transfer device
294+ model = model .to (orig_device )
288295
289- # Remove the transform
290- model .reset_transform ()
296+ # Remove the transform
297+ model .reset_transform ()
291298
292299 # Return sequences as strings
293300 return convert_input_type (X_hat , "strings" )
Original file line number Diff line number Diff line change 22import pandas as pd
33from torch import Tensor , nn
44
5- from grelu .design import evolve
5+ from grelu .design import evolve , ledidi
66from grelu .lightning import LightningModel
77from grelu .transforms .prediction_transforms import Aggregate , Specificity
88from grelu .transforms .seq_transforms import PatternScore
9+ from grelu .sequence .format import check_string_dna
910
1011model = LightningModel (
1112 model_params = {
@@ -220,3 +221,19 @@ def test_evolve_7():
220221 assert np .all (output ["iter" ] == [0 , 1 , 2 ])
221222 assert np .all (output .seq == ['AT' , 'AC' , 'CC' ])
222223 assert np .all (output .label1 == [0.5 , 1.5 , 2. ])
224+
225+
226+ def test_ledidi ():
227+ output = ledidi (
228+ seq = 'GGTATTCATT' ,
229+ model = model ,
230+ prediction_transform = None ,
231+ max_iter = 20000 ,
232+ positions = None ,
233+ devices = "cpu" ,
234+ num_workers = 1 ,
235+ lr = 0.01 ,
236+ early_stopping_iter = 1000
237+ )
238+ assert isinstance (output , list )
239+ assert check_string_dna (output )
You can’t perform that action at this time.
0 commit comments