
Commit a84f41c

try out ty for type checking (#176)
* build: add ty to dev deps for type checking
* types: fix some typing issues in lib code
* tests: fix type issue
* fix: typing issues and minor bugs in extractors
* build: configure ty tool
* ci: run ty in checks
* fix: call ty correctly in ci
* build: exclude nbs from ty check
* build: remove mypy from dev deps and ci
1 parent e856aa7 commit a84f41c

9 files changed (+72 lines, -70 lines)

.github/workflows/checks.yml

Lines changed: 2 additions & 2 deletions

@@ -103,6 +103,6 @@ jobs:
         uses: ./.github/actions/setup-python-env
         with:
           python-version: "3.11"
-      - name: Check types with mypy
+      - name: Check types with ty
         run: |
-          uv run python -m mypy --install-types --non-interactive colandr
+          uv run python -m ty check

colandr/lib/extractors/locations.py

Lines changed: 5 additions & 2 deletions

@@ -2,6 +2,7 @@

 import collections
 import logging
+import typing as t

 from spacy.tokens import Span

@@ -50,7 +51,9 @@ def extract_locations(self, record_id: int, text: str) -> list[Metadata]:
             return []

         processed_docs_iter = process_texts_into_docs([text], max_len=None)
-        doc = next(processed_docs_iter, None)
+        doc = next(iter(processed_docs_iter), None)
+        if doc is None:
+            return []

         # Get all sentences
         sentences = list(doc.sents)

@@ -90,7 +93,7 @@ def extract_locations(self, record_id: int, text: str) -> list[Metadata]:
         return self._group_locations(record_id, locations)

     def _group_locations(
-        self, record_id: int, locations: list[Metadata]
+        self, record_id: int, locations: list[dict[str, t.Any]]
     ) -> list[Metadata]:
         """
         Group locations by name and sort by frequency.
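
The switch from next(processed_docs_iter, None) to next(iter(processed_docs_iter), None), plus the explicit None guard, makes the "first doc or nothing" lookup safe whatever kind of iterable process_texts_into_docs returns. A minimal sketch of the pattern in isolation (the helper name is hypothetical, not part of the codebase):

    import typing as t

    def first_or_none(items: t.Iterable[t.Any]) -> t.Any | None:
        # iter() is a no-op on iterators and wraps lists/tuples/generators,
        # so next() with a default never raises, whatever the caller passes in
        return next(iter(items), None)

    print(first_or_none([]))         # None
    print(first_or_none(["doc_1"]))  # doc_1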

colandr/lib/extractors/review_model.py

Lines changed: 2 additions & 3 deletions

@@ -80,7 +80,7 @@ def transform(self, x: pd.Series) -> np.ndarray:
         Returns:
             A 2D NumPy array of shape (n_samples, n_features).
         """
-        return np.vstack(x)
+        return np.vstack(x.tolist())


 class ReviewModel:

@@ -447,8 +447,7 @@ def _process_text(self, text_content: str) -> tuple[pd.DataFrame, list[dict]]:
         processed_docs_iter = process_texts_into_docs(
             [main_content], max_len=None, exclude=("ner",)
         )
-        doc = next(processed_docs_iter, None)
-
+        doc = next(iter(processed_docs_iter), None)
         return self._extract_features_from_doc(doc)

     def _is_valid_sentence(self, sent: Optional[Span]) -> bool:
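
The change from np.vstack(x) to np.vstack(x.tolist()) is presumably about typing rather than behavior: stacking a pandas Series of per-sample vectors generally works at runtime either way, but converting to a plain list first matches what the numpy stubs expect. A small sketch under that assumption (toy data, not the real feature vectors):

    import numpy as np
    import pandas as pd

    # a Series whose elements are 1-D feature vectors, as in transform()
    s = pd.Series([np.array([1.0, 2.0]), np.array([3.0, 4.0])])

    # .tolist() hands np.vstack a plain sequence of array-likes,
    # which type checkers accept without complaint
    stacked = np.vstack(s.tolist())
    print(stacked.shape)  # (2, 2)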

colandr/lib/fileio/studies/base.py

Lines changed: 2 additions & 2 deletions

@@ -112,7 +112,7 @@ def _from_stream(self, stream: t.IO[bytes], encoding: str) -> str:
         # ).read()
         return data

-    def _standardize_field_names(self, record: dict[str, object]) -> dict[str, object]:
+    def _standardize_field_names(self, record: dict[str, t.Any]) -> dict[str, t.Any]:
         record = {key.lower().replace(" ", "_"): value for key, value in record.items()}
         if self.field_alt_names:
             # only one alt name per field? take this faster path

@@ -127,7 +127,7 @@ def _standardize_field_names(self, record: dict[str, object]) -> dict[str, object]:
                 break
         return record

-    def _sanitize_field_values(self, record: dict[str, object]) -> dict[str, object]:
+    def _sanitize_field_values(self, record: dict[str, t.Any]) -> dict[str, t.Any]:
         if self.field_sanitizers:
             for field, sanitizers in self.field_sanitizers.items():
                 if field in record:
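
Loosening dict[str, object] to dict[str, t.Any] matters to a strict checker: a value typed as object supports almost no attribute access or indexing, while Any defers those checks to runtime. A tiny illustration (hypothetical helper, not from the codebase):

    import typing as t

    def first_author(record: dict[str, t.Any]) -> str:
        # with dict[str, object], the [0] subscript and .strip() would both be
        # flagged, since object supports neither; Any defers the checks to runtime
        return record["authors"][0].strip()

    print(first_author({"authors": ["  Doe, J.  ", "Roe, R."]}))  # Doe, J.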

colandr/lib/nlp/representations.py

Lines changed: 10 additions & 6 deletions

@@ -148,8 +148,8 @@ def __init__(
         self._avg_doc_length = None

     def _validate_vocabulary(
-        self, vocabulary: dict[str, int] | Iterable[str]
-    ) -> tuple[dict[str, int], bool]:
+        self, vocabulary: t.Optional[dict[str, int] | Iterable[str]]
+    ) -> tuple[dict[str, int] | None, bool]:
         """
         Validate an input vocabulary. If it's a mapping, ensure that term ids
         are unique and compact (i.e. without any gaps between 0 and the number

@@ -191,7 +191,7 @@ def _validate_vocabulary(
             is_fixed = True
         else:
             is_fixed = False
-        return (vocabulary, is_fixed)
+        return (vocabulary, is_fixed)  # ty: ignore[invalid-return-type]

     def _check_vocabulary(self):
         """

@@ -211,9 +211,11 @@ def id_to_term(self) -> dict[int, str]:
         generated if needed, and it is automatically kept in sync with the
         corresponding vocabulary.
         """
+        self._check_vocabulary()
         if len(self.id_to_term_) != self.vocabulary_terms:
             self.id_to_term_ = {
-                term_id: term_str for term_str, term_id in self.vocabulary_terms.items()
+                term_id: term_str
+                for term_str, term_id in self.vocabulary_terms.items()  # ty: ignore[possibly-missing-attribute]
             }
         return self.id_to_term_

@@ -235,7 +237,8 @@ def terms_list(self) -> list[str]:
         return [
             term_str
             for term_str, _ in sorted(
-                self.vocabulary_terms.items(), key=operator.itemgetter(1)
+                self.vocabulary_terms.items(),  # ty: ignore[possibly-missing-attribute]
+                key=operator.itemgetter(1),
             )
         ]

@@ -389,6 +392,7 @@ def _count_terms(
             vocabulary.default_factory = vocabulary.__len__
         else:
             vocabulary = self.vocabulary_terms
+        assert vocabulary is not None

         indices = array(str("i"))
         indptr = array(str("i"), [0])

@@ -421,7 +425,7 @@
         # pretty sure this is a good thing to do... o_O
         doc_term_matrix.sort_indices()

-        return doc_term_matrix, vocabulary
+        return (doc_term_matrix, vocabulary)

     def _filter_terms(
         self, doc_term_matrix: sp.csr_matrix, vocabulary: dict[str, int]
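
Two recurring moves in this file: an assert x is not None statement narrows an Optional value for the checker before it is used, and an inline "# ty: ignore[rule-name]" comment silences one named rule on a single line. A minimal sketch of the narrowing idiom (names are illustrative, not the real attributes):

    import typing as t

    def build_index(
        fixed_vocab: t.Optional[dict[str, int]],
        learned_vocab: t.Optional[dict[str, int]],
    ) -> int:
        vocabulary = fixed_vocab if fixed_vocab is not None else learned_vocab
        # both inputs are Optional, so the checker still sees dict[str, int] | None;
        # the assert narrows it to dict[str, int] (and fails fast if both were None)
        assert vocabulary is not None
        return len(vocabulary)

    print(build_index(None, {"term": 0}))  # 1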

colandr/tasks.py

Lines changed: 5 additions & 2 deletions

@@ -12,6 +12,7 @@
 from celery.utils.log import get_task_logger
 from flask import current_app
 from flask_mail import Message
+from sqlalchemy.dialects import postgresql as pg

 from . import models
 from .api.v1 import schemas

@@ -158,7 +159,9 @@ def deduplicate_citations(review_id: int):
         sa.select(models.Study.id)
         .where(models.Study.review_id == review_id)
         # .where(models.Study.citation_status.in_(["included", "excluded"]))
-        .where(models.Study.citation_status == sa.any_(["included", "excluded"]))
+        .where(
+            models.Study.citation_status == sa.any_(pg.array(["included", "excluded"]))
+        )
     )
     incl_excl_sids = set(db.session.execute(stmt).scalars().all())

@@ -197,7 +200,7 @@ def deduplicate_citations(review_id: int):
         )
         .where(models.Study.review_id == review_id)
         # .where(models.Study.id.in_(int_sids))
-        .where(models.Study.id == sa.any_(int_sids))
+        .where(models.Study.id == sa.any_(pg.array(int_sids)))
         .order_by(sa.text("n_null_cols ASC"))
         .limit(1)
     )
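
Wrapping the plain Python list in pg.array(...) gives sa.any_() an ARRAY-typed SQL expression to compare against, i.e. the PostgreSQL col = ANY (ARRAY[...]) idiom used in place of the commented-out .in_() filters. A small standalone sketch (the bare column is illustrative; the real queries filter models.Study):

    import sqlalchemy as sa
    from sqlalchemy.dialects import postgresql as pg

    citation_status = sa.column("citation_status", sa.String)

    # membership test written the Postgres way: col = ANY (ARRAY[...])
    expr = citation_status == sa.any_(pg.array(["included", "excluded"]))
    print(expr.compile(dialect=pg.dialect()))
    # roughly: citation_status = ANY (ARRAY[%(param_1)s, %(param_2)s])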

pyproject.toml

Lines changed: 15 additions & 10 deletions

@@ -62,10 +62,11 @@ Repository = "https://github.com/datakind/permanent-colandr-back"
 [dependency-groups]
 dev = [
     "httpx~=0.28.0",
-    "mypy~=1.0",
     "pytest~=9.0",
     "pytest-postgresql~=7.0",
     "SQLAlchemy-Utils~=0.42.0",
+    # TODO: update ty once officially out of beta
+    "ty~=0.0.7",
     "ruff~=0.14.0",
 ]

@@ -80,15 +81,6 @@ required-version = ">=0.8.0,<0.10.0"
 module-name = "colandr"
 module-root = ""

-[tool.mypy]
-files = ["colandr/**/*.py"]
-python_version = "3.12"
-pretty = true
-ignore_errors = true
-allow_redefinition = true
-ignore_missing_imports = true
-follow_imports = "silent"
-
 [tool.pytest]
 minversion = "9.0"
 addopts = ["--verbose"]

@@ -124,3 +116,16 @@ ignore = ["E501", "E711", "F401", "PLW2901"]
 lines-after-imports = 2
 known-first-party = ["colandr"]
 known-third-party = ["alembic"]
+
+[tool.ty.environment]
+root = ["./colandr"]
+
+[tool.ty.rules]
+# ty appears to be struggling with relative imports :shrug:
+unresolved-import = "ignore"
+
+[tool.ty.src]
+exclude = ["migrations", "notebooks", "tests"]
+
+[tool.ty.terminal]
+output-format = "full"

tests/lib/nlp/test_utils.py

Lines changed: 4 additions & 1 deletion

@@ -1,4 +1,5 @@
 import pytest
+from spacy.language import Language
 from spacy.tokens import Doc

 from colandr.lib.nlp import utils

@@ -125,4 +126,6 @@ def test_process_texts_into_docs(texts, max_len, fallback_lang, app):
     spacy_lang = utils.load_spacy_lang(
         utils.get_lang_to_models()["en"][0], exclude=("parser", "ner")
     )
-    assert spacy_lang(texts[0]).to_bytes() == docs[0].to_bytes()
+    doc = docs[0]
+    assert isinstance(spacy_lang, Language) and isinstance(doc, Doc)  # type guards
+    assert spacy_lang(texts[0]).to_bytes() == doc.to_bytes()
