Skip to content

Commit 692fded

Browse files
authored
Merge pull request #10 from erwindouna/add-date-object
Add date object
2 parents 24517eb + 7c803d2 commit 692fded

File tree

2 files changed

+52
-17
lines changed

2 files changed

+52
-17
lines changed

src/pyfirefly/pyfirefly.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import socket
77
from dataclasses import dataclass
8+
from datetime import datetime
89
from importlib import metadata
910
from typing import Any, Self
1011
from urllib.parse import urlparse
@@ -151,6 +152,20 @@ async def _request(
151152

152153
return await response.json()
153154

155+
def _format_date(self, date_value: datetime | str) -> str:
156+
"""Format a date value to a string in 'YYYY-MM-DD' format.
157+
158+
Args:
159+
date_value: A date object or a string representing a date.
160+
161+
Returns:
162+
A string formatted as 'YYYY-MM-DD'.
163+
164+
"""
165+
if isinstance(date_value, datetime):
166+
return date_value.strftime("%Y-%m-%d")
167+
return date_value
168+
154169
async def get_about(self) -> About:
155170
"""Get information about the Firefly server.
156171
@@ -194,8 +209,8 @@ async def get_accounts(self) -> list[Account]:
194209
async def get_transactions(
195210
self,
196211
account_id: int | None = None,
197-
start: str | None = None,
198-
end: str | None = None,
212+
start: datetime | None = None,
213+
end: datetime | None = None,
199214
) -> list[Transaction]:
200215
"""Get transactions for a specific account. Else, return all transactions.
201216
@@ -220,9 +235,9 @@ async def get_transactions(
220235
while next_page:
221236
params: dict[str, str] = {"page": str(next_page)}
222237
if start:
223-
params["start"] = start
238+
params["start"] = self._format_date(start)
224239
if end:
225-
params["end"] = end
240+
params["end"] = self._format_date(end)
226241

227242
response = await self._request(
228243
uri=uri,
@@ -268,7 +283,12 @@ async def get_categories(self) -> list[Category]:
268283

269284
return [Category.from_dict(cat) for cat in categories]
270285

271-
async def get_category(self, category_id: int, start: str | None = None, end: str | None = None) -> Category:
286+
async def get_category(
287+
self,
288+
category_id: int,
289+
start: datetime | None = None,
290+
end: datetime | None = None,
291+
) -> Category:
272292
"""Get a specific category by its ID.
273293
274294
Args:
@@ -282,14 +302,14 @@ async def get_category(self, category_id: int, start: str | None = None, end: st
282302
"""
283303
params: dict[str, str] = {}
284304
if start:
285-
params["start"] = start
305+
params["start"] = self._format_date(start)
286306
if end:
287-
params["end"] = end
307+
params["end"] = self._format_date(end)
288308

289309
category = await self._request(uri=f"categories/{category_id}", params=params)
290310
return Category.from_dict(category["data"])
291311

292-
async def get_budgets(self, start: str | None = None, end: str | None = None) -> list[Budget]:
312+
async def get_budgets(self, start: datetime | None = None, end: datetime | None = None) -> list[Budget]:
293313
"""Get budgets for the Firefly server. Both start and end dates are required for date range filtering.
294314
295315
Args:
@@ -302,13 +322,13 @@ async def get_budgets(self, start: str | None = None, end: str | None = None) ->
302322
"""
303323
params: dict[str, str] = {}
304324
if start and end:
305-
params["start"] = start
306-
params["end"] = end
325+
params["start"] = self._format_date(start)
326+
params["end"] = self._format_date(end)
307327

308328
budgets = await self._request(uri="budgets", params=params)
309329
return [Budget.from_dict(budget) for budget in budgets["data"]]
310330

311-
async def get_bills(self, start: str | None = None, end: str | None = None) -> list[Bill]:
331+
async def get_bills(self, start: datetime | None = None, end: datetime | None = None) -> list[Bill]:
312332
"""Get bills for the Firefly server. Both start and end dates are required for date range filtering.
313333
314334
Args:
@@ -323,8 +343,8 @@ async def get_bills(self, start: str | None = None, end: str | None = None) -> l
323343
next_page: int | None = 1
324344
params: dict[str, str] = {"page": str(next_page)}
325345
if start and end:
326-
params["start"] = start
327-
params["end"] = end
346+
params["start"] = self._format_date(start)
347+
params["end"] = self._format_date(end)
328348

329349
while next_page:
330350
response = await self._request(uri="bills", params=params)

tests/test_models.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from datetime import datetime, timezone
56
from typing import TYPE_CHECKING
67

78
from aresponses import ResponsesMockServer
@@ -73,7 +74,11 @@ async def test_account_transactions_model(
7374
),
7475
)
7576

76-
transactions = await firefly_client.get_transactions(account_id=1, start="2025-01-01", end="2025-12-31")
77+
transactions = await firefly_client.get_transactions(
78+
account_id=1,
79+
start=datetime(2025, 1, 1, tzinfo=timezone.utc),
80+
end=datetime(2025, 12, 31, tzinfo=timezone.utc),
81+
)
7782
assert transactions == snapshot
7883

7984
# Now for all transactions
@@ -109,7 +114,11 @@ async def test_category_model(
109114
),
110115
)
111116

112-
category = await firefly_client.get_category(category_id=1, start="2025-01-01", end="2025-12-31")
117+
category = await firefly_client.get_category(
118+
category_id=1,
119+
start=datetime(2025, 1, 1, tzinfo=timezone.utc),
120+
end=datetime(2025, 12, 31, tzinfo=timezone.utc),
121+
)
113122
assert category == snapshot
114123

115124
# Now without a date range
@@ -166,7 +175,10 @@ async def test_budgets_model(
166175
),
167176
)
168177

169-
budgets = await firefly_client.get_budgets(start="2025-01-01", end="2025-12-31")
178+
budgets = await firefly_client.get_budgets(
179+
start=datetime(2025, 1, 1, tzinfo=timezone.utc),
180+
end=datetime(2025, 12, 31, tzinfo=timezone.utc),
181+
)
170182
assert budgets == snapshot
171183

172184
# Now without a date range
@@ -202,7 +214,10 @@ async def test_bills_model(
202214
),
203215
)
204216

205-
bills = await firefly_client.get_bills(start="2025-01-01", end="2025-12-31")
217+
bills = await firefly_client.get_bills(
218+
start=datetime(2025, 1, 1, tzinfo=timezone.utc),
219+
end=datetime(2025, 12, 31, tzinfo=timezone.utc),
220+
)
206221
assert bills == snapshot
207222

208223
# Now without a date range

0 commit comments

Comments
 (0)