4 changes: 2 additions & 2 deletions .github/workflows/checks.yml
@@ -103,6 +103,6 @@ jobs:
         uses: ./.github/actions/setup-python-env
         with:
           python-version: "3.11"
-      - name: Check types with mypy
+      - name: Check types with ty
         run: |
-          uv run python -m mypy --install-types --non-interactive colandr
+          uv run python -m ty check
7 changes: 5 additions & 2 deletions colandr/lib/extractors/locations.py
@@ -2,6 +2,7 @@

 import collections
 import logging
+import typing as t
 
 from spacy.tokens import Span

@@ -50,7 +51,9 @@ def extract_locations(self, record_id: int, text: str) -> list[Metadata]:
             return []
 
         processed_docs_iter = process_texts_into_docs([text], max_len=None)
-        doc = next(processed_docs_iter, None)
+        doc = next(iter(processed_docs_iter), None)
+        if doc is None:
+            return []
 
         # Get all sentences
         sentences = list(doc.sents)
@@ -90,7 +93,7 @@ def extract_locations(self, record_id: int, text: str) -> list[Metadata]:
         return self._group_locations(record_id, locations)
 
     def _group_locations(
-        self, record_id: int, locations: list[Metadata]
+        self, record_id: int, locations: list[dict[str, t.Any]]
     ) -> list[Metadata]:
         """
         Group locations by name and sort by frequency.
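Note on the `iter(...)` wrapper: `next()` requires an `Iterator`, while `process_texts_into_docs` is presumably annotated as returning an `Iterable`, which a checker rejects as an argument to `next()` even when the runtime object is a generator. A minimal sketch of the pattern, with a hypothetical `fake_pipe` standing in for the real function:

import typing as t

def fake_pipe(texts: list[str]) -> t.Iterable[str]:
    # stand-in for process_texts_into_docs: annotated Iterable, not Iterator
    return (text.upper() for text in texts)

docs_iter = fake_pipe(["some abstract text"])
# next(docs_iter, None) fails type checking: Iterable has no __next__.
# iter() narrows it to an Iterator at essentially no runtime cost, since
# calling iter() on an iterator returns the same object.
doc = next(iter(docs_iter), None)
if doc is None:
    # mirrors the new early return in extract_locations()
    raise ValueError("no doc produced")
print(doc)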
5 changes: 2 additions & 3 deletions colandr/lib/extractors/review_model.py
@@ -80,7 +80,7 @@ def transform(self, x: pd.Series) -> np.ndarray:
         Returns:
             A 2D NumPy array of shape (n_samples, n_features).
         """
-        return np.vstack(x)
+        return np.vstack(x.tolist())
 
 
 class ReviewModel:
@@ -447,8 +447,7 @@ def _process_text(self, text_content: str) -> tuple[pd.DataFrame, list[dict]]:
         processed_docs_iter = process_texts_into_docs(
             [main_content], max_len=None, exclude=("ner",)
         )
-        doc = next(processed_docs_iter, None)
-
+        doc = next(iter(processed_docs_iter), None)
         return self._extract_features_from_doc(doc)
 
     def _is_valid_sentence(self, sent: Optional[Span]) -> bool:
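Note on `.tolist()`: `np.vstack` stacks a sequence of arrays, and a pandas Series of equal-length vectors does work at runtime, but the numpy stubs expect a sequence of array-likes, which ty presumably flags for a Series. Converting to a plain list sidesteps that. A minimal sketch:

import numpy as np
import pandas as pd

# a Series whose elements are per-sample feature vectors, as transform() receives
x = pd.Series([np.array([1.0, 2.0]), np.array([3.0, 4.0])])

# np.vstack(x) also works at runtime, but .tolist() hands the stubs the
# plain list[np.ndarray] they expect
features = np.vstack(x.tolist())
print(features.shape)  # (2, 2), i.e. (n_samples, n_features)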
4 changes: 2 additions & 2 deletions colandr/lib/fileio/studies/base.py
@@ -112,7 +112,7 @@ def _from_stream(self, stream: t.IO[bytes], encoding: str) -> str:
         # ).read()
         return data
 
-    def _standardize_field_names(self, record: dict[str, object]) -> dict[str, object]:
+    def _standardize_field_names(self, record: dict[str, t.Any]) -> dict[str, t.Any]:
         record = {key.lower().replace(" ", "_"): value for key, value in record.items()}
         if self.field_alt_names:
             # only one alt name per field? take this faster path
@@ -127,7 +127,7 @@ def _standardize_field_names(self, record: dict[str, object]) -> dict[str, object]:
                     break
         return record
 
-    def _sanitize_field_values(self, record: dict[str, object]) -> dict[str, object]:
+    def _sanitize_field_values(self, record: dict[str, t.Any]) -> dict[str, t.Any]:
         if self.field_sanitizers:
             for field, sanitizers in self.field_sanitizers.items():
                 if field in record:
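Note on `object` → `t.Any`: with `dict[str, object]`, the values support almost no operations, so passing them to sanitizer callables that take concrete types fails checking; `t.Any` opts the values out. A minimal sketch with a hypothetical `strip_ws` sanitizer (not part of colandr):

import typing as t

def strip_ws(value: str) -> str:
    return value.strip()

def sanitize(record: dict[str, t.Any]) -> dict[str, t.Any]:
    # with dict[str, object], strip_ws(value) would be rejected:
    # `object` is not assignable to the parameter `value: str`
    return {key: strip_ws(value) for key, value in record.items()}

print(sanitize({"title": "  A Study of Things  "}))  # {'title': 'A Study of Things'}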
16 changes: 10 additions & 6 deletions colandr/lib/nlp/representations.py
@@ -148,8 +148,8 @@ def __init__(
         self._avg_doc_length = None
 
     def _validate_vocabulary(
-        self, vocabulary: dict[str, int] | Iterable[str]
-    ) -> tuple[dict[str, int], bool]:
+        self, vocabulary: t.Optional[dict[str, int] | Iterable[str]]
+    ) -> tuple[dict[str, int] | None, bool]:
         """
         Validate an input vocabulary. If it's a mapping, ensure that term ids
         are unique and compact (i.e. without any gaps between 0 and the number
@@ -191,7 +191,7 @@ def _validate_vocabulary(
             is_fixed = True
         else:
             is_fixed = False
-        return (vocabulary, is_fixed)
+        return (vocabulary, is_fixed)  # ty: ignore[invalid-return-type]
 
     def _check_vocabulary(self):
         """
@@ -211,9 +211,11 @@ def id_to_term(self) -> dict[int, str]:
         generated if needed, and it is automatically kept in sync with the
         corresponding vocabulary.
         """
+        self._check_vocabulary()
         if len(self.id_to_term_) != self.vocabulary_terms:
             self.id_to_term_ = {
-                term_id: term_str for term_str, term_id in self.vocabulary_terms.items()
+                term_id: term_str
+                for term_str, term_id in self.vocabulary_terms.items()  # ty: ignore[possibly-missing-attribute]
             }
         return self.id_to_term_

@@ -235,7 +237,8 @@ def terms_list(self) -> list[str]:
         return [
             term_str
             for term_str, _ in sorted(
-                self.vocabulary_terms.items(), key=operator.itemgetter(1)
+                self.vocabulary_terms.items(),  # ty: ignore[possibly-missing-attribute]
+                key=operator.itemgetter(1),
             )
         ]
 
@@ -389,6 +392,7 @@ def _count_terms(
             vocabulary.default_factory = vocabulary.__len__
         else:
             vocabulary = self.vocabulary_terms
+        assert vocabulary is not None
 
         indices = array(str("i"))
         indptr = array(str("i"), [0])
@@ -421,7 +425,7 @@ def _count_terms(
         # pretty sure this is a good thing to do... o_O
         doc_term_matrix.sort_indices()
 
-        return doc_term_matrix, vocabulary
+        return (doc_term_matrix, vocabulary)
 
     def _filter_terms(
         self, doc_term_matrix: sp.csr_matrix, vocabulary: dict[str, int]
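Note on `assert vocabulary is not None`: once `_validate_vocabulary` was widened to return `dict[str, int] | None`, any downstream use of the dict would be flagged. The assert is the standard narrowing idiom: it convinces the checker the value is a `dict[str, int]` from that point on, and fails loudly at runtime if the invariant ever breaks. A minimal sketch:

def count_terms(vocabulary: dict[str, int] | None) -> int:
    # without the assert, len(vocabulary) is an error: vocabulary may be None.
    # the assert narrows dict[str, int] | None to dict[str, int] for the
    # checker, and raises AssertionError at runtime if the "never None here"
    # invariant is ever violated.
    assert vocabulary is not None
    return len(vocabulary)

print(count_terms({"apple": 0, "banana": 1}))  # 2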
7 changes: 5 additions & 2 deletions colandr/tasks.py
@@ -12,6 +12,7 @@
 from celery.utils.log import get_task_logger
 from flask import current_app
 from flask_mail import Message
+from sqlalchemy.dialects import postgresql as pg
 
 from . import models
 from .api.v1 import schemas
@@ -158,7 +159,9 @@ def deduplicate_citations(review_id: int):
         sa.select(models.Study.id)
         .where(models.Study.review_id == review_id)
         # .where(models.Study.citation_status.in_(["included", "excluded"]))
-        .where(models.Study.citation_status == sa.any_(["included", "excluded"]))
+        .where(
+            models.Study.citation_status == sa.any_(pg.array(["included", "excluded"]))
+        )
     )
     incl_excl_sids = set(db.session.execute(stmt).scalars().all())
 
@@ -197,7 +200,7 @@ def deduplicate_citations(review_id: int):
         )
         .where(models.Study.review_id == review_id)
         # .where(models.Study.id.in_(int_sids))
-        .where(models.Study.id == sa.any_(int_sids))
+        .where(models.Study.id == sa.any_(pg.array(int_sids)))
         .order_by(sa.text("n_null_cols ASC"))
         .limit(1)
     )
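Note on `pg.array(...)`: `sa.any_` wants an array-valued SQL expression; SQLAlchemy coerces a plain Python list at runtime, but that form evidently trips ty, and `pg.array(...)` makes the PostgreSQL ARRAY literal explicit. A minimal standalone sketch, using a bare `sa.column` in place of the real model:

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as pg

citation_status = sa.column("citation_status", sa.String)

# renders roughly: citation_status = ANY (ARRAY['included', 'excluded'])
expr = citation_status == sa.any_(pg.array(["included", "excluded"]))
print(expr.compile(dialect=pg.dialect(), compile_kwargs={"literal_binds": True}))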
25 changes: 15 additions & 10 deletions pyproject.toml
@@ -62,10 +62,11 @@ Repository = "https://github.com/datakind/permanent-colandr-back"
 [dependency-groups]
 dev = [
     "httpx~=0.28.0",
-    "mypy~=1.0",
     "pytest~=9.0",
     "pytest-postgresql~=7.0",
     "SQLAlchemy-Utils~=0.42.0",
+    # TODO: update ty once officially out of beta
+    "ty~=0.0.7",
     "ruff~=0.14.0",
 ]
 
@@ -80,15 +81,6 @@ required-version = ">=0.8.0,<0.10.0"
module-name = "colandr"
module-root = ""

[tool.mypy]
files = ["colandr/**/*.py"]
python_version = "3.12"
pretty = true
ignore_errors = true
allow_redefinition = true
ignore_missing_imports = true
follow_imports = "silent"

[tool.pytest]
minversion = "9.0"
addopts = ["--verbose"]
@@ -124,3 +116,16 @@ ignore = ["E501", "E711", "F401", "PLW2901"]
 lines-after-imports = 2
 known-first-party = ["colandr"]
 known-third-party = ["alembic"]
+
+[tool.ty.environment]
+root = ["./colandr"]
+
+[tool.ty.rules]
+# ty appears to be struggling with relative imports :shrug:
+unresolved-import = "ignore"
+
+[tool.ty.src]
+exclude = ["migrations", "notebooks", "tests"]
+
+[tool.ty.terminal]
+output-format = "full"
5 changes: 4 additions & 1 deletion tests/lib/nlp/test_utils.py
@@ -1,4 +1,5 @@
 import pytest
+from spacy.language import Language
 from spacy.tokens import Doc
 
 from colandr.lib.nlp import utils
@@ -125,4 +126,6 @@ def test_process_texts_into_docs(texts, max_len, fallback_lang, app):
     spacy_lang = utils.load_spacy_lang(
         utils.get_lang_to_models()["en"][0], exclude=("parser", "ner")
     )
-    assert spacy_lang(texts[0]).to_bytes() == docs[0].to_bytes()
+    doc = docs[0]
+    assert isinstance(spacy_lang, Language) and isinstance(doc, Doc)  # type guards
+    assert spacy_lang(texts[0]).to_bytes() == doc.to_bytes()
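Note on the isinstance asserts: `load_spacy_lang` and the parametrized `docs` are presumably typed too broadly for ty to prove that calling `spacy_lang(...)` or `doc.to_bytes()` is safe, so the test narrows both with runtime type guards first. A minimal sketch of the idiom, with hypothetical loosely-typed inputs:

from spacy.language import Language
from spacy.tokens import Doc

def doc_bytes(nlp: object, docs: list[Doc | None]) -> bytes:
    # `nlp: object` and the Optional docs stand in for the loosely typed values
    # in the test; the asserts guard at runtime and narrow the types for the
    # checker, so the calls below are accepted.
    assert isinstance(nlp, Language)
    doc = docs[0]
    assert isinstance(doc, Doc)
    return doc.to_bytes()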