@@ -39,43 +39,31 @@ def unmasked_wt_score(
3939 train : bool = False ,
4040 cut_special_tokens : bool = True , # assumption: cut first and last token
4141 device = None ,
42- ** kwargs
42+ verbose : bool = False ,
43+ ** model_kwargs
4344 ):
4445 if device is None :
4546 device = get_device ()
4647 if wt_input_ids .dim () == 1 :
4748 wt_input_ids = wt_input_ids .unsqueeze (0 )
48- structure_input_ids = kwargs .get ("structure_input_ids" , None )
49+ # structure_input_ids = model_kwargs .get("structure_input_ids", None)
4950
5051 attention_masks = torch .Tensor (np .full (
5152 shape = np .shape (wt_input_ids ), fill_value = attention_mask )).to (torch .int64 )
5253 if train :
53- if structure_input_ids is not None :
54+ outputs = model (
55+ input_ids = wt_input_ids .to (device ),
56+ attention_mask = attention_masks .to (device ),
57+ ** model_kwargs
58+ )
59+
60+ else :
61+ with torch .no_grad ():
5462 outputs = model (
5563 input_ids = wt_input_ids .to (device ),
5664 attention_mask = attention_masks .to (device ),
57- ss_input_ids = structure_input_ids .to (device )
58- )
59- else :
60- outputs = model (
61- wt_input_ids .to (device ),
62- attention_masks .to (device ),
63- output_hidden_states = False
65+ ** model_kwargs
6466 )
65- else :
66- with torch .no_grad ():
67- if structure_input_ids is not None :
68- outputs = model (
69- input_ids = wt_input_ids .to (device ),
70- attention_mask = attention_masks .to (device ),
71- ss_input_ids = structure_input_ids .to (device )
72- )
73- else :
74- outputs = model (
75- wt_input_ids .to (device ),
76- attention_masks .to (device ),
77- output_hidden_states = False ,
78- )
7967
8068 logits = outputs .logits
8169 logits = logits .squeeze (0 ) # remove batch dim
@@ -105,7 +93,7 @@ def unmasked_wt_score(
10593 return log_probs
10694
10795
108- def esm_mutation_only_mutation_masked_pll (
96+ def mutation_only_mutation_masked_pll (
10997 tokenized_sequences : torch .Tensor , # (L,)
11098 wt_input_ids : torch .Tensor , # (L,)
11199 attention_mask : torch .Tensor , # (L,)
@@ -198,7 +186,7 @@ def esm_mutation_only_mutation_masked_pll(
198186 return plls
199187
200188
201- def esm_mutation_all_pos_masked_pll (
189+ def mutation_all_pos_masked_pll (
202190 tokenized_sequences : torch .Tensor , # (L,)
203191 attention_mask : torch .Tensor , # (L,)
204192 model ,
@@ -285,7 +273,7 @@ def plm_inference(
285273 wt_input_ids ,
286274 attention_mask ,
287275 model ,
288- mask_token_id ,
276+ mask_token_id = None ,
289277 inference_type = 'unmasked' ,
290278 wt_structure_input_ids = None ,
291279 batch_size = 5 ,
@@ -304,9 +292,9 @@ def plm_inference(
304292 if not isinstance (attention_mask , torch .Tensor ):
305293 attention_mask = torch .tensor (attention_mask , dtype = torch .long )
306294 if inference_type == 'mutation-masking' :
307- inference_function = esm_mutation_only_mutation_masked_pll
295+ inference_function = mutation_only_mutation_masked_pll
308296 elif inference_type in ['full-masking' , 'all-pos-masking' ]:
309- inference_function = esm_mutation_all_pos_masked_pll
297+ inference_function = mutation_all_pos_masked_pll
310298 elif inference_type in ['unmasked' , 'wt-marginals' ]:
311299 inference_function = unmasked_wt_score
312300 else :
@@ -317,6 +305,13 @@ def plm_inference(
317305 xs_b = get_batches (xs , dtype = int , batch_size = batch_size , keep_remaining = True , verbose = True )
318306 desc = f"Inference: { inference_type } batch (size={ batch_size } ) processing ({ device .upper ()} )'"
319307
308+ kwargs = {}
309+ if mask_token_id is not None :
310+ kwargs ["mask_token_id" ] = mask_token_id
311+
312+ if wt_structure_input_ids is not None :
313+ kwargs ["structure_input_ids" ] = wt_structure_input_ids
314+
320315 pbar = tqdm (
321316 range (len (xs_b )),
322317 desc = desc ,
@@ -327,13 +322,12 @@ def plm_inference(
327322 pll = inference_function (
328323 tokenized_sequences = torch .tensor (xs_b [i ]),
329324 wt_input_ids = wt_input_ids ,
330- structure_input_ids = wt_structure_input_ids ,
331325 attention_mask = attention_mask ,
332326 model = model ,
333- mask_token_id = mask_token_id ,
334327 train = train ,
335328 device = device ,
336- verbose = False
329+ verbose = False ,
330+ ** kwargs
337331 )
338332 scores .append (pll )
339333 return torch .cat (scores )
0 commit comments