Skip to content

Commit 3170243

Browse files
authored
Merge pull request #105 from Point72/pit/datetime
Add DatetimeContext and allow datetimes to be passed to DateContext
2 parents 0c32044 + 2219de6 commit 3170243

File tree

3 files changed

+52
-23
lines changed

3 files changed

+52
-23
lines changed

ccflow/context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"NullContext",
1414
"GenericContext",
1515
"DateContext",
16+
"DatetimeContext",
1617
"EntryTimeContext",
1718
"DateRangeContext",
1819
"VersionedDateContext",
@@ -96,6 +97,19 @@ def _date_context_validator(cls, v, handler, info):
9697
return handler(v)
9798

9899

100+
class DatetimeContext(ContextBase):
101+
dt: datetime
102+
103+
@model_validator(mode="wrap")
104+
def _datetime_context_validator(cls, v, handler, info):
105+
if cls is DatetimeContext and not isinstance(v, (DatetimeContext, dict)):
106+
if isinstance(v, (tuple, list)) and len(v) == 1:
107+
v = v[0]
108+
109+
v = DatetimeContext(dt=v)
110+
return handler(v)
111+
112+
99113
class EntryTimeContext(ContextBase):
100114
entry_time_cutoff: Optional[datetime] = None
101115

ccflow/tests/test_context.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from datetime import date, timedelta
1+
from datetime import date, datetime, timedelta, timezone
22
from unittest import TestCase
33

44
import pandas as pd
5-
from pydantic import BaseModel, ValidationError
5+
from pydantic import TypeAdapter, ValidationError
66

77
from ccflow.context import (
88
DateContext,
99
DateRangeContext,
10+
DatetimeContext,
1011
FreqContext,
1112
FreqDateContext,
1213
FreqDateRangeContext,
@@ -24,14 +25,6 @@
2425
from ccflow.result import GenericResult
2526

2627

27-
class MyModel(BaseModel):
28-
context: DateContext
29-
30-
31-
class MyRangeModel(BaseModel):
32-
context: DateRangeContext
33-
34-
3528
class TestContexts(TestCase):
3629
def test_null_context(self):
3730
n1 = NullContext()
@@ -54,13 +47,32 @@ def test_date_validation(self):
5447
self.assertEqual(DateContext(date="-1d"), c1)
5548
self.assertRaises(ValueError, DateContext, date="foo")
5649

57-
# Test coercion to DateContext on nested models
58-
self.assertEqual(MyModel(context={"date": date.today()}).context, c)
59-
self.assertEqual(MyModel(context=date.today()).context, c)
60-
self.assertEqual(MyModel(context=str(date.today())).context, c)
61-
self.assertEqual(MyModel(context="0d").context, c)
62-
self.assertEqual(MyModel(context="-1d").context, c1)
63-
self.assertRaises(ValueError, MyModel, context="foo")
50+
# Test validation from other types
51+
self.assertEqual(TypeAdapter(DateContext).validate_python({"date": date.today()}), c)
52+
self.assertEqual(TypeAdapter(DateContext).validate_python(date.today()), c)
53+
self.assertEqual(TypeAdapter(DateContext).validate_python([date.today()]), c)
54+
self.assertEqual(TypeAdapter(DateContext).validate_python(str(date.today())), c)
55+
self.assertEqual(TypeAdapter(DateContext).validate_python("0d"), c)
56+
self.assertEqual(TypeAdapter(DateContext).validate_python("-1d"), c1)
57+
self.assertRaises(ValueError, TypeAdapter(DateContext).validate_python, "foo")
58+
59+
# Test validation from datetime (not normally allowed by pydantic)
60+
dt = datetime(2022, 1, 1, 12, tzinfo=timezone.utc)
61+
self.assertEqual(TypeAdapter(DateContext).validate_python(dt), DateContext(date=dt.date()))
62+
63+
def test_datetime_validation(self):
64+
dt = datetime(2022, 1, 1, 12, 0, tzinfo=timezone.utc)
65+
c = DatetimeContext(dt=dt)
66+
self.assertEqual(DatetimeContext(dt=str(dt)), c)
67+
self.assertEqual(DatetimeContext(dt=dt.date()), DatetimeContext(dt=datetime(2022, 1, 1)))
68+
69+
# Test validation from other types
70+
self.assertEqual(TypeAdapter(DatetimeContext).validate_python({"dt": dt}), c)
71+
self.assertEqual(TypeAdapter(DatetimeContext).validate_python(dt), c)
72+
self.assertEqual(TypeAdapter(DatetimeContext).validate_python([dt]), c)
73+
self.assertEqual(TypeAdapter(DatetimeContext).validate_python(str(dt)), c)
74+
self.assertEqual(TypeAdapter(DatetimeContext).validate_python(dt.isoformat()), c)
75+
self.assertRaises(ValueError, TypeAdapter(DatetimeContext).validate_python, "foo")
6476

6577
def test_coercion(self):
6678
d = DateContext(date=date(2022, 1, 1))
@@ -76,10 +88,11 @@ def test_date_range(self):
7688
self.assertEqual(DateRangeContext(start_date="-1d", end_date="0d"), c)
7789
self.assertRaises(ValueError, DateRangeContext, start_date="foo", end_date=d1)
7890

79-
# Test coercion to DateContext on nested models
80-
self.assertEqual(MyRangeModel(context={"start_date": d0, "end_date": d1}).context, c)
81-
self.assertEqual(MyRangeModel(context=("-1d", "0d")).context, c)
82-
self.assertEqual(MyRangeModel(context=["-1d", "0d"]).context, c)
91+
# Test validation from other types
92+
self.assertEqual(TypeAdapter(DateRangeContext).validate_python({"start_date": d0, "end_date": d1}), c)
93+
self.assertEqual(TypeAdapter(DateRangeContext).validate_python(("-1d", "0d")), c)
94+
self.assertEqual(TypeAdapter(DateRangeContext).validate_python(["-1d", "0d"]), c)
95+
self.assertEqual(TypeAdapter(DateRangeContext).validate_python(["-1d", datetime.now()]), c)
8396

8497
def test_freq(self):
8598
self.assertEqual(

ccflow/validators.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""This module contains common validators."""
22

33
import logging
4-
from datetime import date
4+
from datetime import date, datetime
55
from typing import Any, Dict, Optional
66

77
import pandas as pd
@@ -11,13 +11,15 @@
1111

1212

1313
def normalize_date(v: Any) -> Any:
14-
"""Validator that will convert string offsets to date based on today."""
14+
"""Validator that will convert string offsets to date based on today, and convert datetime to date."""
1515
if isinstance(v, str): # Check case where it's an offset
1616
try:
1717
timestamp = pd.tseries.frequencies.to_offset(v) + date.today()
1818
return timestamp.date()
1919
except ValueError:
2020
pass
21+
if isinstance(v, datetime):
22+
return v.date()
2123
return v
2224

2325

0 commit comments

Comments
 (0)