22
33from typing import Dict , Optional , Any , List , Literal , Union
44from abc import ABC , abstractmethod
5- from pathlib import Path
65from functools import lru_cache
76import json
8- import yaml
9- import tempfile
107
118import pydantic
12- import requests
139import pandas
1410
1511from ginkgo_ai_client .utils import (
1612 fasta_sequence_iterator ,
1713 IteratorWithLength ,
18- cif_to_pdb ,
1914)
2015
2116## ---- Base classes --------------------------------------------------------------
@@ -48,7 +43,6 @@ def parse_response(self, results: Dict) -> Any:
4843
4944
5045class ResponseBase (pydantic .BaseModel ):
51-
5246 def write_to_jsonl (self , path : str ):
5347 with open (path , "a" ) as f :
5448 f .write (self .model_dump_json () + "\n " )
@@ -71,9 +65,18 @@ def write_to_jsonl(self, path: str):
7165 for model , sequence_type in _maskedlm_models_properties .items ()
7266)
7367
68+ SPECIAL_TOKENS = ["<mask>" , "<unk>" , "<pad>" , "<cls>" , "<eos>" ]
69+
70+
71+ def _lowercase_all_special_tokens (sequence : str ) -> str :
72+ """Lower-case all special tokens in a sequence."""
73+ for special_token in SPECIAL_TOKENS :
74+ sequence = sequence .replace (special_token .upper (), special_token )
75+ return sequence
76+
7477
7578def _validate_model_and_sequence (
76- model : str , sequence : str , allow_masks : bool = False , extra_chars : List [str ] = []
79+ model : str , sequence : str , allow_masks : bool = False , extra_tokens : List [str ] = None
7780):
7881 """Raise an error if the model is unknown or the sequence isn't compatible.
7982
@@ -92,21 +95,25 @@ def _validate_model_and_sequence(
9295 valid_models = list (_maskedlm_models_properties .keys ())
9396 if model not in valid_models :
9497 raise ValueError (f"Model '{ model } ' unknown. Sould be one of { valid_models } " )
98+ extra_tokens = SPECIAL_TOKENS + (extra_tokens or [])
9599 sequence_type = _maskedlm_models_properties [model ]
96- if allow_masks :
97- sequence = sequence .replace ("<mask>" , "" )
98- chars = {
99- "dna" : set ("ATGC" ),
100- "dna-iupac" : set ("ATGCNRSYWKMDHBV" ),
101- "protein" : set ("ACDEFGHIKLMNPQRSTVWY" ),
102- }[sequence_type ]
103100
104- chars = chars .union (set ([e .upper () for e in extra_chars ]))
101+ sequence_without_extra_tokens = sequence
102+ for token in extra_tokens :
103+ sequence_without_extra_tokens = sequence_without_extra_tokens .replace (token , "" )
105104
106- if not set (sequence .upper ()).issubset (chars ):
105+ allowed_chars = {
106+ "dna" : set ("ATGCatgc" ),
107+ "dna-iupac" : set ("ATGCNRSYWKMDHBVatgcnywkmdbvh" ),
108+ "protein" : set ("ACDEFGHIKLMNPQRSTVWY" ), # only uppercase allowed
109+ }[sequence_type ]
110+ unallowed_chars = set (sequence_without_extra_tokens ) - allowed_chars
111+ if unallowed_chars :
107112 raise ValueError (
108113 f"Model { model } requires the sequence to only contain "
109- f"the following characters (lower or upper-case): { '' .join (chars )} "
114+ f"the following characters: { '' .join (sorted (allowed_chars ))} "
115+ f"and the extra tokens { extra_tokens } (these can be upper-case). "
116+ f"The following unparsable characters were found: { '' .join (sorted (unallowed_chars ))} "
110117 )
111118
112119
@@ -166,8 +173,9 @@ def parse_response(self, results: Dict) -> EmbeddingResponse:
166173
167174 @pydantic .model_validator (mode = "after" )
168175 def check_model_and_sequence_compatibility (cls , query ):
176+ query .sequence = _lowercase_all_special_tokens (query .sequence )
169177 sequence , model = query .sequence , query .model
170- _validate_model_and_sequence (model = model , sequence = sequence , allow_masks = False )
178+ _validate_model_and_sequence (model = model , sequence = sequence )
171179 return query
172180
173181 @classmethod
@@ -237,8 +245,9 @@ def parse_response(self, response: Dict) -> SequenceResponse:
237245
238246 @pydantic .model_validator (mode = "after" )
239247 def check_model_and_sequence_compatibility (cls , query ):
248+ query .sequence = _lowercase_all_special_tokens (query .sequence )
240249 sequence , model = query .sequence , query .model
241- _validate_model_and_sequence (model = model , sequence = sequence , allow_masks = True )
250+ _validate_model_and_sequence (model = model , sequence = sequence )
242251 return query
243252
244253
@@ -525,7 +534,6 @@ class RNADiffusionMaskedQuery(QueryBase):
525534 query_name : Optional [str ] = None
526535
527536 def to_request_params (self ) -> Dict :
528-
529537 data = {
530538 "three_utr" : self .three_utr ,
531539 "five_utr" : self .five_utr ,
@@ -572,12 +580,18 @@ def get_species_dataframe(cls):
572580
573581 @pydantic .model_validator (mode = "after" )
574582 def validate_query (cls , query ):
575-
576- _validate_model_and_sequence (query .model , query .three_utr , allow_masks = True )
577- _validate_model_and_sequence (query .model , query .five_utr , allow_masks = True )
583+ query .three_utr = _lowercase_all_special_tokens (query .three_utr )
584+ query .five_utr = _lowercase_all_special_tokens (query .five_utr )
585+ query .protein_sequence = _lowercase_all_special_tokens (query .protein_sequence )
586+ _validate_model_and_sequence (query .model , query .three_utr )
587+ _validate_model_and_sequence (query .model , query .five_utr )
578588 # extra char for "-" that denotes end of the protein sequence
589+ if "<mask>" in query .protein_sequence :
590+ raise ValueError (
591+ "protein_sequence cannot contain <mask> in the RNA diffusion model."
592+ )
579593 _validate_model_and_sequence (
580- "esm2-650M" , query .protein_sequence , allow_masks = False , extra_chars = ["-" ]
594+ "esm2-650M" , query .protein_sequence , extra_tokens = ["-" ]
581595 )
582596
583597 if query .species not in cls .get_species_dataframe ().Species .tolist ():
@@ -679,7 +693,6 @@ def validate_query(cls, query):
679693 _validate_model_and_sequence (
680694 model = model ,
681695 sequence = sequence ,
682- allow_masks = True ,
683696 )
684697 # Validate temperature
685698 if not 0 <= query .temperature <= 1 :
0 commit comments