Skip to content

Commit a3804a6

Browse files
Add py.typed marker for PEP 561 support (#151) (#153)
* switch type checker from ty to pyrefly - Replace ty with pyrefly in dev deps, CI workflow, and pyproject config - Add @overload on _clamp_difficulty/_clamp_stability for proper float/Tensor typing - Replace isinstance(x, Real) with isinstance(x, (int, float)) for pyrefly narrowing - Add wildcard match branches for exhaustive pattern checking - Remove all type ignore comments (0 errors, 0 suppressed) * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent dd4a357 commit a3804a6

File tree

4 files changed

+55
-27
lines changed

4 files changed

+55
-27
lines changed

.github/workflows/type-check.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
name: ty typecheck
1+
name: pyrefly typecheck
22

33
on: [push, pull_request]
44

55
jobs:
6-
mypy:
6+
pyrefly:
77
runs-on: ubuntu-latest
88
steps:
99
- uses: actions/checkout@v4
@@ -15,8 +15,8 @@ jobs:
1515
python-version: '3.14'
1616

1717
- name: Install dependencies
18-
run: uv pip install ty
18+
run: uv pip install pyrefly
1919

20-
- name: Run ty
21-
run: ty check
20+
- name: Run pyrefly
21+
run: pyrefly check
2222

fsrs/py.typed

Whitespace-only changes.

fsrs/scheduler.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from __future__ import annotations
1212
from collections.abc import Sequence
13-
from numbers import Real
1413
import math
1514
from datetime import datetime, timezone, timedelta
1615
from copy import copy
@@ -21,7 +20,10 @@
2120
from fsrs.card import Card
2221
from fsrs.rating import Rating
2322
from fsrs.review_log import ReviewLog
24-
from typing import TypedDict
23+
from typing import TYPE_CHECKING, TypedDict, overload
24+
25+
if TYPE_CHECKING:
26+
from torch import Tensor # torch is optional; import only for type checking
2527
from typing_extensions import Self
2628

2729
FSRS_DEFAULT_DECAY = 0.1542
@@ -355,6 +357,9 @@ def review_card(
355357
)
356358
next_interval = timedelta(days=next_interval_days)
357359

360+
case _:
361+
raise ValueError(f"Unknown rating: {rating}")
362+
358363
case State.Review:
359364
assert card.stability is not None
360365
assert card.difficulty is not None
@@ -401,6 +406,9 @@ def review_card(
401406
)
402407
next_interval = timedelta(days=next_interval_days)
403408

409+
case _:
410+
raise ValueError(f"Unknown rating: {rating}")
411+
404412
case State.Relearning:
405413
assert card.stability is not None
406414
assert card.difficulty is not None
@@ -485,6 +493,12 @@ def review_card(
485493
)
486494
next_interval = timedelta(days=next_interval_days)
487495

496+
case _:
497+
raise ValueError(f"Unknown rating: {rating}")
498+
499+
case _:
500+
raise ValueError(f"Unknown card state: {card.state}")
501+
488502
if self.enable_fuzzing and card.state == State.Review:
489503
next_interval = self._get_fuzzed_interval(interval=next_interval)
490504

@@ -619,19 +633,27 @@ def from_json(cls, source_json: str) -> Self:
619633
source_dict: SchedulerDict = json.loads(source_json)
620634
return cls.from_dict(source_dict=source_dict)
621635

622-
def _clamp_difficulty(self, *, difficulty: float) -> float:
623-
if isinstance(difficulty, Real):
636+
@overload
637+
def _clamp_difficulty(self, *, difficulty: float) -> float: ...
638+
@overload
639+
def _clamp_difficulty(self, *, difficulty: Tensor) -> Tensor: ...
640+
def _clamp_difficulty(self, *, difficulty: float | Tensor) -> float | Tensor:
641+
if isinstance(difficulty, (int, float)):
624642
difficulty = min(max(difficulty, MIN_DIFFICULTY), MAX_DIFFICULTY)
625-
else: # type(difficulty) is torch.Tensor
626-
difficulty = difficulty.clamp(min=MIN_DIFFICULTY, max=MAX_DIFFICULTY) # type: ignore[attr-defined]
643+
else:
644+
difficulty = difficulty.clamp(min=MIN_DIFFICULTY, max=MAX_DIFFICULTY)
627645

628646
return difficulty
629647

630-
def _clamp_stability(self, *, stability: float) -> float:
631-
if isinstance(stability, Real):
648+
@overload
649+
def _clamp_stability(self, *, stability: float) -> float: ...
650+
@overload
651+
def _clamp_stability(self, *, stability: Tensor) -> Tensor: ...
652+
def _clamp_stability(self, *, stability: float | Tensor) -> float | Tensor:
653+
if isinstance(stability, (int, float)):
632654
stability = max(stability, STABILITY_MIN)
633-
else: # type(stability) is torch.Tensor
634-
stability = stability.clamp(min=STABILITY_MIN) # type: ignore[attr-defined]
655+
else:
656+
stability = stability.clamp(min=STABILITY_MIN)
635657

636658
return stability
637659

@@ -657,10 +679,8 @@ def _next_interval(self, *, stability: float) -> int:
657679
(self.desired_retention ** (1 / self._DECAY)) - 1
658680
)
659681

660-
if not isinstance(next_interval, Real): # type(next_interval) is torch.Tensor
661-
next_interval = (
662-
next_interval.detach().item() # ty: ignore[possibly-missing-attribute]
663-
)
682+
if not isinstance(next_interval, (int, float)):
683+
next_interval = next_interval.detach().item()
664684

665685
next_interval = round(next_interval) # intervals are full days
666686

@@ -678,10 +698,10 @@ def _short_term_stability(self, *, stability: float, rating: Rating) -> float:
678698
) * (stability ** -self.parameters[19])
679699

680700
if rating in (Rating.Good, Rating.Easy):
681-
if isinstance(short_term_stability_increase, Real):
701+
if isinstance(short_term_stability_increase, (int, float)):
682702
short_term_stability_increase = max(short_term_stability_increase, 1.0)
683-
else: # type(short_term_stability_increase) is torch.Tensor
684-
short_term_stability_increase = short_term_stability_increase.clamp( # ty: ignore[possibly-missing-attribute]
703+
else:
704+
short_term_stability_increase = short_term_stability_increase.clamp(
685705
min=1.0
686706
)
687707

@@ -734,6 +754,9 @@ def _next_stability(
734754
rating=rating,
735755
)
736756

757+
else:
758+
raise ValueError(f"Unknown rating: {rating}")
759+
737760
next_stability = self._clamp_stability(stability=next_stability)
738761

739762
return next_stability
@@ -803,7 +826,8 @@ def _get_fuzz_range(*, interval_days: int) -> tuple[int, int]:
803826
delta = 1.0
804827
for fuzz_range in FUZZ_RANGES:
805828
delta += fuzz_range["factor"] * max(
806-
min(interval_days, fuzz_range["end"]) - fuzz_range["start"], 0.0
829+
min(float(interval_days), fuzz_range["end"]) - fuzz_range["start"],
830+
0.0,
807831
)
808832

809833
min_ivl = int(round(interval_days - delta))

pyproject.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,17 @@ requires-python = ">=3.10"
2424
Homepage = "https://github.com/open-spaced-repetition/py-fsrs"
2525

2626
[project.optional-dependencies]
27-
dev = ["pytest", "ruff", "setuptools", "torch", "numpy", "pandas", "uv", "pytest-xdist", "pdoc", "tqdm", "pytest-cov", "ty"]
27+
dev = ["pytest", "ruff", "setuptools", "torch", "numpy", "pandas", "uv", "pytest-xdist", "pdoc", "tqdm", "pytest-cov", "pyrefly"]
2828
optimizer = ["torch", "numpy", "pandas", "tqdm"]
2929

3030
[tool.pytest.ini_options]
3131
pythonpath = "."
3232
filterwarnings = ["error"]
3333

34-
[tool.ty.src]
35-
include = ["fsrs"]
36-
exclude = ["fsrs/optimizer.py"]
34+
[tool.setuptools.package-data]
35+
fsrs = ["py.typed"]
36+
37+
[tool.pyrefly]
38+
project_includes = ["fsrs"]
39+
project_excludes = ["fsrs/optimizer.py"]
40+
ignore-missing-imports = ["torch"]

0 commit comments

Comments
 (0)