|
4 | 4 | datetime,
|
5 | 5 | timedelta,
|
6 | 6 | )
|
7 |
| -from typing import TYPE_CHECKING |
| 7 | +from typing import ( |
| 8 | + TYPE_CHECKING, |
| 9 | + Literal, |
| 10 | + overload, |
| 11 | +) |
8 | 12 | import warnings
|
9 | 13 |
|
10 | 14 | from dateutil.relativedelta import (
|
@@ -281,6 +285,17 @@ def __repr__(self) -> str:
|
281 | 285 | repr = f"Holiday: {self.name} ({info})"
|
282 | 286 | return repr
|
283 | 287 |
|
| 288 | + @overload |
| 289 | + def dates(self, start_date, end_date, return_name: Literal[True]) -> Series: ... |
| 290 | + |
| 291 | + @overload |
| 292 | + def dates( |
| 293 | + self, start_date, end_date, return_name: Literal[False] |
| 294 | + ) -> DatetimeIndex: ... |
| 295 | + |
| 296 | + @overload |
| 297 | + def dates(self, start_date, end_date) -> DatetimeIndex: ... |
| 298 | + |
284 | 299 | def dates(
|
285 | 300 | self, start_date, end_date, return_name: bool = False
|
286 | 301 | ) -> Series | DatetimeIndex:
|
@@ -411,7 +426,7 @@ def _apply_rule(self, dates: DatetimeIndex) -> DatetimeIndex:
|
411 | 426 | return dates
|
412 | 427 |
|
413 | 428 |
|
414 |
| -holiday_calendars = {} |
| 429 | +holiday_calendars: dict[str, type[AbstractHolidayCalendar]] = {} |
415 | 430 |
|
416 | 431 |
|
417 | 432 | def register(cls) -> None:
|
@@ -449,7 +464,7 @@ class AbstractHolidayCalendar(metaclass=HolidayCalendarMetaClass):
|
449 | 464 | rules: list[Holiday] = []
|
450 | 465 | start_date = Timestamp(datetime(1970, 1, 1))
|
451 | 466 | end_date = Timestamp(datetime(2200, 12, 31))
|
452 |
| - _cache = None |
| 467 | + _cache: tuple[Timestamp, Timestamp, Series] | None = None |
453 | 468 |
|
454 | 469 | def __init__(self, name: str = "", rules=None) -> None:
|
455 | 470 | """
|
@@ -478,7 +493,9 @@ def rule_from_name(self, name: str) -> Holiday | None:
|
478 | 493 |
|
479 | 494 | return None
|
480 | 495 |
|
481 |
| - def holidays(self, start=None, end=None, return_name: bool = False): |
| 496 | + def holidays( |
| 497 | + self, start=None, end=None, return_name: bool = False |
| 498 | + ) -> DatetimeIndex | Series: |
482 | 499 | """
|
483 | 500 | Returns a curve with holidays between start_date and end_date
|
484 | 501 |
|
@@ -515,14 +532,9 @@ def holidays(self, start=None, end=None, return_name: bool = False):
|
515 | 532 | rule.dates(start, end, return_name=True) for rule in self.rules
|
516 | 533 | ]
|
517 | 534 | if pre_holidays:
|
518 |
| - # error: Argument 1 to "concat" has incompatible type |
519 |
| - # "List[Union[Series, DatetimeIndex]]"; expected |
520 |
| - # "Union[Iterable[DataFrame], Mapping[<nothing>, DataFrame]]" |
521 |
| - holidays = concat(pre_holidays) # type: ignore[arg-type] |
| 535 | + holidays = concat(pre_holidays) |
522 | 536 | else:
|
523 |
| - # error: Incompatible types in assignment (expression has type |
524 |
| - # "Series", variable has type "DataFrame") |
525 |
| - holidays = Series(index=DatetimeIndex([]), dtype=object) # type: ignore[assignment] |
| 537 | + holidays = Series(index=DatetimeIndex([]), dtype=object) |
526 | 538 |
|
527 | 539 | self._cache = (start, end, holidays.sort_index())
|
528 | 540 |
|
|
0 commit comments