Skip to content

Commit c0fda38

Browse files
committed
Add find_or_create_target_gene_by_sequence and find_or_create_target_gene_by_accession functions to avoid recreating target genes on each score set update
1 parent 250fdb7 commit c0fda38

File tree

3 files changed

+158
-35
lines changed

3 files changed

+158
-35
lines changed

src/mavedb/lib/target_genes.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import logging
22
from typing import Optional
33

4-
from sqlalchemy import func, or_
4+
from mavedb.models.target_accession import TargetAccession
5+
from mavedb.models.target_sequence import TargetSequence
6+
from mavedb.models.taxonomy import Taxonomy
7+
from sqlalchemy import func, or_, and_
58
from sqlalchemy.orm import Session
69

710
from mavedb.lib.logging.context import logging_context, save_to_logging_context
@@ -13,6 +16,121 @@
1316

1417
logger = logging.getLogger(__name__)
1518

19+
def find_or_create_target_gene_by_accession(
    db: Session,
    score_set_id: int,
    tg: dict,
    tg_accession: dict,
) -> TargetGene:
    """
    Return the accession-based target gene of a score set, creating it when no match exists.

    A lookup is attempted only when ``tg_accession`` is not None and carries a truthy
    ``"accession"`` entry. A matching target gene must belong to ``score_set_id``, have the
    name and category given in ``tg``, and be linked to a target accession whose accession,
    assembly, gene and ``is_base_editor`` flag (treated as False when the key is absent)
    equal the values in ``tg_accession``.

    When nothing matches, a ``TargetAccession`` is built from ``tg_accession`` and a
    ``TargetGene`` from ``tg``. The new gene is added to the session, committed, and
    refreshed before it is returned.

    :param db: Database session used for the lookup and, if needed, the insert.
    :param score_set_id: ID of the score set the target gene belongs to.
    :param tg: Target gene attributes. ``"name"`` and ``"category"`` are read for the lookup;
        every entry is passed to the ``TargetGene`` constructor on creation.
    :param tg_accession: Target accession attributes. Every entry is passed to the
        ``TargetAccession`` constructor on creation.
    :return: The existing or newly created target gene.
    """
    logger.info(
        msg=f"Searching for existing target gene by accession within score set {score_set_id}.",
        extra=logging_context(),
    )

    match = None
    if tg_accession is not None and tg_accession.get("accession"):
        candidates = db.query(TargetGene)

        # Criterion on the related accession record: every identifying field must agree.
        same_accession = TargetGene.target_accession.has(
            and_(
                TargetAccession.accession == tg_accession["accession"],
                TargetAccession.assembly == tg_accession["assembly"],
                TargetAccession.gene == tg_accession["gene"],
                TargetAccession.is_base_editor == tg_accession.get("is_base_editor", False),
            )
        )
        match = candidates.filter(
            and_(
                same_accession,
                TargetGene.name == tg["name"],
                TargetGene.category == tg["category"],
                TargetGene.score_set_id == score_set_id,
            )
        ).first()

    if match is not None:
        logger.info(
            msg=f"Found existing target gene '{match.name}' with ID {match.id}.",
            extra=logging_context(),
        )
        return match

    # No reusable record: build the accession and its owning gene, then persist them.
    new_accession = TargetAccession(**tg_accession)
    created = TargetGene(
        **tg,
        score_set_id=score_set_id,
        target_accession=new_accession,
    )
    db.add(created)
    db.commit()
    # Reload the committed row before its attributes are read for the log message below.
    db.refresh(created)
    logger.info(
        msg=f"Created new target gene '{created.name}' with ID {created.id}.",
        extra=logging_context(),
    )
    return created
75+
76+
def find_or_create_target_gene_by_sequence(
    db: Session,
    score_set_id: int,
    tg: dict,
    tg_sequence: dict,
) -> TargetGene:
    """
    Return the sequence-based target gene of a score set, creating it when no match exists.

    A lookup is attempted only when ``tg_sequence`` is not None and carries a truthy
    ``"sequence"`` entry. A matching target gene must belong to ``score_set_id``, have the
    name and category given in ``tg``, and be linked to a target sequence whose sequence,
    sequence type and label equal the values in ``tg_sequence`` and whose taxonomy has the
    same ``id`` as ``tg_sequence["taxonomy"]``.

    When nothing matches, a ``TargetSequence`` is built from ``tg_sequence`` and a
    ``TargetGene`` from ``tg``. The new gene is added to the session, committed, and
    refreshed before it is returned.

    :param db: Database session used for the lookup and, if needed, the insert.
    :param score_set_id: ID of the score set the target gene belongs to.
    :param tg: Target gene attributes. ``"name"`` and ``"category"`` are read for the lookup;
        every entry is passed to the ``TargetGene`` constructor on creation.
    :param tg_sequence: Target sequence attributes. ``"taxonomy"`` must be an object exposing
        ``.id``; every entry is passed to the ``TargetSequence`` constructor on creation.
    :return: The existing or newly created target gene.
    """
    logger.info(
        msg=f"Searching for existing target gene by sequence within score set {score_set_id}.",
        extra=logging_context(),
    )

    match = None
    if tg_sequence is not None and tg_sequence.get("sequence"):
        candidates = db.query(TargetGene)

        # Criterion on the related sequence record, including its taxonomy's primary key.
        same_sequence = TargetGene.target_sequence.has(
            and_(
                TargetSequence.sequence == tg_sequence["sequence"],
                TargetSequence.sequence_type == tg_sequence["sequence_type"],
                TargetSequence.taxonomy.has(Taxonomy.id == tg_sequence["taxonomy"].id),
                TargetSequence.label == tg_sequence["label"],
            )
        )
        match = candidates.filter(
            and_(
                same_sequence,
                TargetGene.name == tg["name"],
                TargetGene.category == tg["category"],
                TargetGene.score_set_id == score_set_id,
            )
        ).first()

    if match is not None:
        logger.info(
            msg=f"Found existing target gene '{match.name}' with ID {match.id}.",
            extra=logging_context(),
        )
        return match

    # No reusable record: build the sequence and its owning gene, then persist them.
    new_sequence = TargetSequence(**tg_sequence)
    created = TargetGene(
        **tg,
        score_set_id=score_set_id,
        target_sequence=new_sequence,
    )
    db.add(created)
    db.commit()
    # Reload the committed row before its attributes are read for the log message below.
    db.refresh(created)
    logger.info(
        msg=f"Created new target gene '{created.name}' with ID {created.id}.",
        extra=logging_context(),
    )
    return created
16134

17135
def search_target_genes(
18136
db: Session,

src/mavedb/routers/score_sets.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from fastapi.responses import StreamingResponse
1212
from ga4gh.va_spec.acmg_2015 import VariantPathogenicityEvidenceLine
1313
from ga4gh.va_spec.base.core import Statement, ExperimentalVariantFunctionalImpactStudyResult
14+
from mavedb.lib.target_genes import find_or_create_target_gene_by_accession, find_or_create_target_gene_by_sequence
1415
from mavedb.view_models.contributor import ContributorCreate
1516
from mavedb.view_models.doi_identifier import DoiIdentifierCreate
1617
from mavedb.view_models.publication_identifier import PublicationIdentifierCreate
@@ -123,11 +124,6 @@ async def enqueue_variant_creation(
123124
scores_column_metadata = item.dataset_columns.get("scores_column_metadata")
124125
counts_column_metadata = item.dataset_columns.get("counts_column_metadata")
125126

126-
# Although this is also updated within the variant creation job, update it here
127-
# as well so that we can display the proper UI components (queue invocation delay
128-
# races the score set GET request).
129-
item.processing_state = ProcessingState.processing
130-
131127
# await the insertion of this job into the worker queue, not the job itself.
132128
job = await worker.enqueue_job(
133129
"create_variants_for_score_set",
@@ -161,7 +157,7 @@ async def score_set_update(
161157
item_update_dict: dict[str, Any] = item_update.model_dump(exclude_unset=exclude_unset)
162158

163159
item = db.query(ScoreSet).filter(ScoreSet.urn == urn).one_or_none()
164-
if not item:
160+
if not item or item.id is None:
165161
logger.info(msg="Failed to update score set; The requested score set does not exist.", extra=logging_context())
166162
raise HTTPException(status_code=404, detail=f"score set with URN '{urn}' not found")
167163

@@ -255,17 +251,9 @@ async def score_set_update(
255251
if "score_ranges" in item_update_dict:
256252
item.score_ranges = item_update_dict.get("score_ranges", null())
257253

258-
# If item_update_dict includes target_genes, delete the old target gene, WT sequence, and reference map. These will be deleted when we set the score set's
259-
# target_gene to None, because we have set cascade='all,delete-orphan' on ScoreSet.target_gene. (Since the
260-
# relationship is defined with the target gene as owner, this is actually set up in the backref attribute of
261-
# TargetGene.score_set.)
262-
#
263-
# We must flush our database queries now so that the old target gene will be deleted before inserting a new one
264-
# with the same score_set_id.
265-
266254
if "target_genes" in item_update_dict:
267-
item.target_genes = []
268-
db.flush()
255+
assert all(tg.id is not None for tg in item.target_genes)
256+
existing_target_ids: list[int] = [tg.id for tg in item.target_genes if tg.id is not None]
269257

270258
targets: List[TargetGene] = []
271259
accessions = False
@@ -301,17 +289,11 @@ async def score_set_update(
301289
# View model validation rules enforce that sequences must have a label defined if more than one
302290
# target is defined on a score set.
303291
seq_label = gene.target_sequence.label if gene.target_sequence.label is not None else gene.name
304-
target_sequence = TargetSequence(
305-
**jsonable_encoder(
306-
gene.target_sequence,
307-
by_alias=False,
308-
exclude={"taxonomy", "label"},
309-
),
310-
taxonomy=taxonomy,
311-
label=seq_label,
312-
)
313-
target_gene = TargetGene(
314-
**jsonable_encoder(
292+
293+
target_gene = target_gene = find_or_create_target_gene_by_sequence(
294+
db,
295+
score_set_id=item.id,
296+
tg=jsonable_encoder(
315297
gene,
316298
by_alias=False,
317299
exclude={
@@ -320,7 +302,11 @@ async def score_set_update(
320302
"target_accession",
321303
},
322304
),
323-
target_sequence=target_sequence,
305+
tg_sequence={
306+
**jsonable_encoder(gene.target_sequence, by_alias=False, exclude={"taxonomy", "label"}),
307+
"taxonomy": taxonomy,
308+
"label": seq_label,
309+
}
324310
)
325311

326312
elif gene.target_accession:
@@ -333,9 +319,11 @@ async def score_set_update(
333319
"MaveDB does not support score-sets with both sequence and accession based targets. Please re-submit this scoreset using only one type of target."
334320
)
335321
accessions = True
336-
target_accession = TargetAccession(**jsonable_encoder(gene.target_accession, by_alias=False))
337-
target_gene = TargetGene(
338-
**jsonable_encoder(
322+
323+
target_gene = find_or_create_target_gene_by_accession(
324+
db,
325+
score_set_id=item.id,
326+
tg=jsonable_encoder(
339327
gene,
340328
by_alias=False,
341329
exclude={
@@ -344,7 +332,7 @@ async def score_set_update(
344332
"target_accession",
345333
},
346334
),
347-
target_accession=target_accession,
335+
tg_accession=jsonable_encoder(gene.target_accession, by_alias=False),
348336
)
349337
else:
350338
save_to_logging_context({"failing_target": gene})
@@ -365,7 +353,13 @@ async def score_set_update(
365353
targets.append(target_gene)
366354

367355
item.target_genes = targets
368-
should_create_variants = True if item.variants else False
356+
357+
assert all(tg.id is not None for tg in item.target_genes)
358+
current_target_ids: list[int] = [tg.id for tg in item.target_genes if tg.id is not None]
359+
360+
if sorted(existing_target_ids) != sorted(current_target_ids):
361+
logger.info(msg=f"Target genes have changed for score set {item.id}", extra=logging_context())
362+
should_create_variants = True if item.variants else False
369363

370364
else:
371365
logger.debug(msg="Skipped score range and target gene update. Score set is published.", extra=logging_context())
@@ -1131,6 +1125,7 @@ async def create_score_set(
11311125
# View model validation rules enforce that sequences must have a label defined if more than one
11321126
# target is defined on a score set.
11331127
seq_label = gene.target_sequence.label if gene.target_sequence.label is not None else gene.name
1128+
11341129
target_sequence = TargetSequence(
11351130
**jsonable_encoder(gene.target_sequence, by_alias=False, exclude={"taxonomy", "label"}),
11361131
taxonomy=taxonomy,
@@ -1454,8 +1449,17 @@ async def update_score_set(
14541449
should_create_variants = itemUpdateResult["should_create_variants"]
14551450

14561451
if should_create_variants:
1452+
# Although this is also updated within the variant creation job, update it here
1453+
# as well so that we can display the proper UI components (queue invocation delay
1454+
# races the score set GET request).
1455+
updatedItem.processing_state = ProcessingState.processing
1456+
14571457
await enqueue_variant_creation(item=updatedItem, user_data=user_data, worker=worker)
14581458

1459+
db.add(updatedItem)
1460+
db.commit()
1461+
db.refresh(updatedItem)
1462+
14591463
enriched_experiment = enrich_experiment_with_num_score_sets(updatedItem.experiment, user_data)
14601464
return score_set.ScoreSet.model_validate(updatedItem).copy(update={"experiment": enriched_experiment})
14611465

src/mavedb/view_models/target_gene.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Any, Optional, Sequence
33
from typing_extensions import Self
44

5-
from pydantic import Field, model_validator
5+
from pydantic import Field, model_validator, ConfigDict
66

77
from mavedb.lib.validation.exceptions import ValidationError
88
from mavedb.lib.validation.transform import transform_external_identifier_offsets_to_list, transform_score_set_to_urn
@@ -25,6 +25,7 @@ class TargetGeneBase(BaseModel):
2525
category: TargetCategory
2626
external_identifiers: Sequence[external_gene_identifier_offset.ExternalGeneIdentifierOffsetBase]
2727

28+
model_config = ConfigDict(from_attributes=True)
2829

2930
class TargetGeneModify(TargetGeneBase):
3031
pass

0 commit comments

Comments
 (0)