Skip to content

Commit 68b0938

Browse files
author
Valentin Zulkower
committed
better validation, better handling of special tokens, better error messages
1 parent 6d53a69 commit 68b0938

File tree

3 files changed

+76
-25
lines changed

3 files changed

+76
-25
lines changed

ginkgo_ai_client/queries.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,15 @@
22

33
from typing import Dict, Optional, Any, List, Literal, Union
44
from abc import ABC, abstractmethod
5-
from pathlib import Path
65
from functools import lru_cache
76
import json
8-
import yaml
9-
import tempfile
107

118
import pydantic
12-
import requests
139
import pandas
1410

1511
from ginkgo_ai_client.utils import (
1612
fasta_sequence_iterator,
1713
IteratorWithLength,
18-
cif_to_pdb,
1914
)
2015

2116
## ---- Base classes --------------------------------------------------------------
@@ -48,7 +43,6 @@ def parse_response(self, results: Dict) -> Any:
4843

4944

5045
class ResponseBase(pydantic.BaseModel):
51-
5246
def write_to_jsonl(self, path: str):
5347
with open(path, "a") as f:
5448
f.write(self.model_dump_json() + "\n")
@@ -71,9 +65,18 @@ def write_to_jsonl(self, path: str):
7165
for model, sequence_type in _maskedlm_models_properties.items()
7266
)
7367

68+
SPECIAL_TOKENS = ["<mask>", "<unk>", "<pad>", "<cls>", "<eos>"]
69+
70+
71+
def _lowercase_all_special_tokens(sequence: str) -> str:
72+
"""Lower-case all special tokens in a sequence."""
73+
for special_token in SPECIAL_TOKENS:
74+
sequence = sequence.replace(special_token.upper(), special_token)
75+
return sequence
76+
7477

7578
def _validate_model_and_sequence(
76-
model: str, sequence: str, allow_masks: bool = False, extra_chars: List[str] = []
79+
model: str, sequence: str, allow_masks: bool = False, extra_tokens: List[str] = None
7780
):
7881
"""Raise an error if the model is unknown or the sequence isn't compatible.
7982
@@ -92,21 +95,25 @@ def _validate_model_and_sequence(
9295
valid_models = list(_maskedlm_models_properties.keys())
9396
if model not in valid_models:
9497
raise ValueError(f"Model '{model}' unknown. Sould be one of {valid_models}")
98+
extra_tokens = SPECIAL_TOKENS + (extra_tokens or [])
9599
sequence_type = _maskedlm_models_properties[model]
96-
if allow_masks:
97-
sequence = sequence.replace("<mask>", "")
98-
chars = {
99-
"dna": set("ATGC"),
100-
"dna-iupac": set("ATGCNRSYWKMDHBV"),
101-
"protein": set("ACDEFGHIKLMNPQRSTVWY"),
102-
}[sequence_type]
103100

104-
chars = chars.union(set([e.upper() for e in extra_chars]))
101+
sequence_without_extra_tokens = sequence
102+
for token in extra_tokens:
103+
sequence_without_extra_tokens = sequence_without_extra_tokens.replace(token, "")
105104

106-
if not set(sequence.upper()).issubset(chars):
105+
allowed_chars = {
106+
"dna": set("ATGCatgc"),
107+
"dna-iupac": set("ATGCNRSYWKMDHBVatgcnywkmdbvh"),
108+
"protein": set("ACDEFGHIKLMNPQRSTVWY"), # only uppercase allowed
109+
}[sequence_type]
110+
unallowed_chars = set(sequence_without_extra_tokens) - allowed_chars
111+
if unallowed_chars:
107112
raise ValueError(
108113
f"Model {model} requires the sequence to only contain "
109-
f"the following characters (lower or upper-case): {''.join(chars)}"
114+
f"the following characters: {''.join(sorted(allowed_chars))} "
115+
f"and the extra tokens {extra_tokens} (these can be upper-case). "
116+
f"The following unparsable characters were found: {''.join(sorted(unallowed_chars))}"
110117
)
111118

112119

@@ -166,8 +173,9 @@ def parse_response(self, results: Dict) -> EmbeddingResponse:
166173

167174
@pydantic.model_validator(mode="after")
168175
def check_model_and_sequence_compatibility(cls, query):
176+
query.sequence = _lowercase_all_special_tokens(query.sequence)
169177
sequence, model = query.sequence, query.model
170-
_validate_model_and_sequence(model=model, sequence=sequence, allow_masks=False)
178+
_validate_model_and_sequence(model=model, sequence=sequence)
171179
return query
172180

173181
@classmethod
@@ -237,8 +245,9 @@ def parse_response(self, response: Dict) -> SequenceResponse:
237245

238246
@pydantic.model_validator(mode="after")
239247
def check_model_and_sequence_compatibility(cls, query):
248+
query.sequence = _lowercase_all_special_tokens(query.sequence)
240249
sequence, model = query.sequence, query.model
241-
_validate_model_and_sequence(model=model, sequence=sequence, allow_masks=True)
250+
_validate_model_and_sequence(model=model, sequence=sequence)
242251
return query
243252

244253

@@ -525,7 +534,6 @@ class RNADiffusionMaskedQuery(QueryBase):
525534
query_name: Optional[str] = None
526535

527536
def to_request_params(self) -> Dict:
528-
529537
data = {
530538
"three_utr": self.three_utr,
531539
"five_utr": self.five_utr,
@@ -572,12 +580,18 @@ def get_species_dataframe(cls):
572580

573581
@pydantic.model_validator(mode="after")
574582
def validate_query(cls, query):
575-
576-
_validate_model_and_sequence(query.model, query.three_utr, allow_masks=True)
577-
_validate_model_and_sequence(query.model, query.five_utr, allow_masks=True)
583+
query.three_utr = _lowercase_all_special_tokens(query.three_utr)
584+
query.five_utr = _lowercase_all_special_tokens(query.five_utr)
585+
query.protein_sequence = _lowercase_all_special_tokens(query.protein_sequence)
586+
_validate_model_and_sequence(query.model, query.three_utr)
587+
_validate_model_and_sequence(query.model, query.five_utr)
578588
# extra char for "-" that denotes end of the protein sequence
589+
if "<mask>" in query.protein_sequence:
590+
raise ValueError(
591+
"protein_sequence cannot contain <mask> in the RNA diffusion model."
592+
)
579593
_validate_model_and_sequence(
580-
"esm2-650M", query.protein_sequence, allow_masks=False, extra_chars=["-"]
594+
"esm2-650M", query.protein_sequence, extra_tokens=["-"]
581595
)
582596

583597
if query.species not in cls.get_species_dataframe().Species.tolist():
@@ -679,7 +693,6 @@ def validate_query(cls, query):
679693
_validate_model_and_sequence(
680694
model=model,
681695
sequence=sequence,
682-
allow_masks=True,
683696
)
684697
# Validate temperature
685698
if not 0 <= query.temperature <= 1:

test/test_models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,25 @@ def test_masked_inference(model, sequence, expected_sequence):
2727
assert results.sequence == expected_sequence
2828

2929

30+
def test_with_uppercase_mask():
31+
client = GinkgoAIClient()
32+
results = client.send_request(
33+
MaskedInferenceQuery(
34+
sequence="MCL<MASK>YAFVATDA<MASK>DDT", model="ginkgo-aa0-650M"
35+
)
36+
)
37+
assert results.sequence == "MCLLYAFVATDADDDT"
38+
39+
40+
def test_that_unknown_tokens_are_accepted_and_not_unmasked():
41+
client = GinkgoAIClient()
42+
sequence = "MCL<unk>YAFVATDA<unk>DDT"
43+
results = client.send_request(
44+
MaskedInferenceQuery(sequence=sequence, model="ginkgo-aa0-650M")
45+
)
46+
assert results.sequence == "MCL<unk>YAFVATDA<unk>DDT"
47+
48+
3049
@pytest.mark.parametrize(
3150
"model, sequence, expected_length",
3251
[

test/test_query_creation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,25 @@ def test_that_forgetting_to_name_arguments_raises_the_better_error_message():
1414
MeanEmbeddingQuery("MLLK<mask>P", model="ginkgo-aa0-650M")
1515

1616

17+
def test_that_simple_query_creation_works():
18+
MeanEmbeddingQuery(sequence="MLLKLP", model="ginkgo-aa0-650M")
19+
20+
21+
def test_that_uppercase_mask_is_accepted():
22+
MeanEmbeddingQuery(sequence="MLLK<MASK>P", model="ginkgo-aa0-650M")
23+
24+
25+
def test_that_lowercase_protein_sequence_is_not_accepted():
26+
expected_error_message = re.escape(
27+
"Model ginkgo-aa0-650M requires the sequence to only contain the following "
28+
"characters: ACDEFGHIKLMNPQRSTVWY and the extra tokens ['<mask>', '<unk>', "
29+
"'<pad>', '<cls>', '<eos>'] (these can be upper-case). "
30+
"The following unparsable characters were found: k"
31+
)
32+
with pytest.raises(ValueError, match=expected_error_message):
33+
MeanEmbeddingQuery(sequence="MLLk<mask>P", model="ginkgo-aa0-650M")
34+
35+
1736
def test_promoter_activity_query_validation():
1837
with pytest.raises(ValueError):
1938
_query = PromoterActivityQuery(

0 commit comments

Comments
 (0)