1010
1111from __future__ import annotations
1212from collections .abc import Sequence
13- from numbers import Real
1413import math
1514from datetime import datetime , timezone , timedelta
1615from copy import copy
2120from fsrs .card import Card
2221from fsrs .rating import Rating
2322from fsrs .review_log import ReviewLog
24- from typing import TypedDict
23+ from typing import TYPE_CHECKING , TypedDict , overload
24+
25+ if TYPE_CHECKING :
26+ from torch import Tensor # torch is optional; import only for type checking
2527from typing_extensions import Self
2628
2729FSRS_DEFAULT_DECAY = 0.1542
@@ -355,6 +357,9 @@ def review_card(
355357 )
356358 next_interval = timedelta (days = next_interval_days )
357359
360+ case _:
361+ raise ValueError (f"Unknown rating: { rating } " )
362+
358363 case State .Review :
359364 assert card .stability is not None
360365 assert card .difficulty is not None
@@ -401,6 +406,9 @@ def review_card(
401406 )
402407 next_interval = timedelta (days = next_interval_days )
403408
409+ case _:
410+ raise ValueError (f"Unknown rating: { rating } " )
411+
404412 case State .Relearning :
405413 assert card .stability is not None
406414 assert card .difficulty is not None
@@ -485,6 +493,12 @@ def review_card(
485493 )
486494 next_interval = timedelta (days = next_interval_days )
487495
496+ case _:
497+ raise ValueError (f"Unknown rating: { rating } " )
498+
499+ case _:
500+ raise ValueError (f"Unknown card state: { card .state } " )
501+
488502 if self .enable_fuzzing and card .state == State .Review :
489503 next_interval = self ._get_fuzzed_interval (interval = next_interval )
490504
@@ -619,19 +633,27 @@ def from_json(cls, source_json: str) -> Self:
619633 source_dict : SchedulerDict = json .loads (source_json )
620634 return cls .from_dict (source_dict = source_dict )
621635
622- def _clamp_difficulty (self , * , difficulty : float ) -> float :
623- if isinstance (difficulty , Real ):
636+ @overload
637+ def _clamp_difficulty (self , * , difficulty : float ) -> float : ...
638+ @overload
639+ def _clamp_difficulty (self , * , difficulty : Tensor ) -> Tensor : ...
640+ def _clamp_difficulty (self , * , difficulty : float | Tensor ) -> float | Tensor :
641+ if isinstance (difficulty , (int , float )):
624642 difficulty = min (max (difficulty , MIN_DIFFICULTY ), MAX_DIFFICULTY )
625- else : # type(difficulty) is torch.Tensor
626- difficulty = difficulty .clamp (min = MIN_DIFFICULTY , max = MAX_DIFFICULTY ) # type: ignore[attr-defined]
643+ else :
644+ difficulty = difficulty .clamp (min = MIN_DIFFICULTY , max = MAX_DIFFICULTY )
627645
628646 return difficulty
629647
630- def _clamp_stability (self , * , stability : float ) -> float :
631- if isinstance (stability , Real ):
648+ @overload
649+ def _clamp_stability (self , * , stability : float ) -> float : ...
650+ @overload
651+ def _clamp_stability (self , * , stability : Tensor ) -> Tensor : ...
652+ def _clamp_stability (self , * , stability : float | Tensor ) -> float | Tensor :
653+ if isinstance (stability , (int , float )):
632654 stability = max (stability , STABILITY_MIN )
633- else : # type(stability) is torch.Tensor
634- stability = stability .clamp (min = STABILITY_MIN ) # type: ignore[attr-defined]
655+ else :
656+ stability = stability .clamp (min = STABILITY_MIN )
635657
636658 return stability
637659
@@ -657,10 +679,8 @@ def _next_interval(self, *, stability: float) -> int:
657679 (self .desired_retention ** (1 / self ._DECAY )) - 1
658680 )
659681
660- if not isinstance (next_interval , Real ): # type(next_interval) is torch.Tensor
661- next_interval = (
662- next_interval .detach ().item () # ty: ignore[possibly-missing-attribute]
663- )
682+ if not isinstance (next_interval , (int , float )):
683+ next_interval = next_interval .detach ().item ()
664684
665685 next_interval = round (next_interval ) # intervals are full days
666686
@@ -678,10 +698,10 @@ def _short_term_stability(self, *, stability: float, rating: Rating) -> float:
678698 ) * (stability ** - self .parameters [19 ])
679699
680700 if rating in (Rating .Good , Rating .Easy ):
681- if isinstance (short_term_stability_increase , Real ):
701+ if isinstance (short_term_stability_increase , ( int , float ) ):
682702 short_term_stability_increase = max (short_term_stability_increase , 1.0 )
683- else : # type(short_term_stability_increase) is torch.Tensor
684- short_term_stability_increase = short_term_stability_increase .clamp ( # ty: ignore[possibly-missing-attribute]
703+ else :
704+ short_term_stability_increase = short_term_stability_increase .clamp (
685705 min = 1.0
686706 )
687707
@@ -734,6 +754,9 @@ def _next_stability(
734754 rating = rating ,
735755 )
736756
757+ else :
758+ raise ValueError (f"Unknown rating: { rating } " )
759+
737760 next_stability = self ._clamp_stability (stability = next_stability )
738761
739762 return next_stability
@@ -803,7 +826,8 @@ def _get_fuzz_range(*, interval_days: int) -> tuple[int, int]:
803826 delta = 1.0
804827 for fuzz_range in FUZZ_RANGES :
805828 delta += fuzz_range ["factor" ] * max (
806- min (interval_days , fuzz_range ["end" ]) - fuzz_range ["start" ], 0.0
829+ min (float (interval_days ), fuzz_range ["end" ]) - fuzz_range ["start" ],
830+ 0.0 ,
807831 )
808832
809833 min_ivl = int (round (interval_days - delta ))
0 commit comments