@@ -349,11 +349,12 @@ def llm_tokenizer(llm_dict, seqs, verbose=True):
349349 if list (llm_dict .keys ())[0 ] == 'esm1v' :
350350 x_llm_seqs , _attention_mask = tokenize_sequences (
351351 seqs , tokenizer = llm_dict ['esm1v' ]['llm_tokenizer' ],
352- max_length = len (seqs [0 ]), verbose = verbose
352+ max_length = len (seqs [0 ]) + 2 , verbose = verbose
353353 )
354354 elif list (llm_dict .keys ())[0 ] == 'prosst' :
355- x_llm_seqs = prosst_simple_vocab_aa_tokenizer (
356- seqs , vocab = llm_dict ['prosst' ]['llm_vocab' ], verbose = verbose
355+ x_llm_seqs , _attention_mask = tokenize_sequences (
356+ seqs , tokenizer = llm_dict ['prosst' ]['llm_tokenizer' ],
357+ max_length = len (seqs [0 ]) + 2 , verbose = verbose
357358 )
358359 else :
359360 raise SystemError (f"Unknown LLM dictionary input:\n { list (llm_dict .keys ())[0 ]} " )
@@ -376,17 +377,29 @@ def inference(
376377 device = get_device ()
377378 if llm == 'esm' :
378379 logger .info ("Zero-shot LLM inference on test set using ESM1v..." )
379- llm_dict = esm_setup (sequences , verbose = verbose )
380+ llm_dict = esm_setup (wt_seq , sequences , verbose = verbose )
380381 if model is None :
381382 model = llm_dict ['esm1v' ]['llm_base_model' ]
382383 x_llm_test = llm_tokenizer (llm_dict , sequences , verbose )
383384 y_test_pred = esm_infer (#llm_dict['esm1v']['llm_inference_function'](
384- xs = torch .tensor (get_batches (x_llm_test , batch_size = 1 , dtype = int )),
385+ xs = torch .from_numpy (get_batches (x_llm_test , batch_size = 1 , dtype = int )),
385386 attention_mask = llm_dict ['esm1v' ]['llm_attention_mask' ],
386387 model = model ,
387388 device = device ,
388389 verbose = verbose
389390 ).cpu ()
391+ y_test_pred = plm_inference (
392+ xs = x_llm_test ,
393+ wt_input_ids = torch .tensor (llm_dict ['esm1v' ]['input_ids' ][0 ], dtype = torch .long ),
394+ attention_mask = llm_dict ['esm1v' ]['llm_attention_mask' ],
395+ model = model ,
396+ mask_token_id = llm_dict ['esm1v' ]['llm_tokenizer' ].mask_token_id ,
397+ inference_type = 'unmasked' ,
398+ batch_size = 5 ,
399+ train = False ,
400+ verbose = True
401+ ).cpu ()
402+
390403 elif llm == 'prosst' :
391404 logger .info ("Zero-shot LLM inference on test set using ProSST..." )
392405 llm_dict = prosst_setup (
@@ -395,14 +408,27 @@ def inference(
395408 if model is None :
396409 model = llm_dict ['prosst' ]['llm_base_model' ]
397410 x_llm_test = llm_tokenizer (llm_dict , sequences , verbose )
398- y_test_pred = prosst_infer (#llm_dict['prosst']['llm_inference_function'](
399- xs = x_llm_test ,
400- model = model ,
401- input_ids = llm_dict ['prosst' ]['input_ids' ],
402- attention_mask = llm_dict ['prosst' ]['llm_attention_mask' ],
403- structure_input_ids = llm_dict ['prosst' ]['structure_input_ids' ],
404- verbose = verbose ,
405- device = device
411+ #y_test_pred = prosst_infer(#llm_dict['prosst']['llm_inference_function'](
412+ # xs=x_llm_test,
413+ # model=model,
414+ # input_ids=llm_dict['prosst']['input_ids'],
415+ # attention_mask=llm_dict['prosst']['llm_attention_mask'],
416+ # structure_input_ids=llm_dict['prosst']['structure_input_ids'],
417+ # verbose=verbose,
418+ # device=device
419+ #).cpu()
420+ print ('XXX:' , np .shape (x_llm_test ))
421+ y_test_pred = plm_inference (
422+ xs = x_llm_test ,
423+ wt_input_ids = llm_dict ['prosst' ]['input_ids' ],
424+ attention_mask = llm_dict ['prosst' ]['llm_attention_mask' ],
425+ model = model ,
426+ mask_token_id = llm_dict ['prosst' ]['llm_tokenizer' ].mask_token_id ,
427+ inference_type = 'mutation-masking' ,
428+ wt_structure_input_ids = llm_dict ['prosst' ]['structure_input_ids' ],
429+ batch_size = 5 ,
430+ train = False ,
431+ verbose = True
406432 ).cpu ()
407433 else :
408434 raise RuntimeError ("Unknown LLM option." )
0 commit comments