Skip to content

Commit 0f0df9c

Browse files
Merge pull request #16 from nsidc/ardwa-4
ARDWA-4
2 parents 16947d2 + 0505031 commit 0f0df9c

File tree

4 files changed

+150
-6
lines changed

4 files changed

+150
-6
lines changed

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ dependencies = [
2525
"loguru",
2626
"tqdm",
2727
"fastapi ~=0.111.0",
28+
"matplotlib",
29+
"numpy",
30+
"pandas",
2831
"pydantic ~=2.0",
2932
"pydantic-settings",
3033
"sqlalchemy ~=2.0",
@@ -61,8 +64,6 @@ docs = [
6164
ui = [
6265
"jupyterlab",
6366
"leafmap",
64-
"matplotlib",
65-
"pandas",
6667
]
6768

6869
[project.urls]

src/aross_stations_db/api/v1/output.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime as dt
2+
import io
23
from typing import Annotated
34

45
from annotated_types import Ge, Le
@@ -8,8 +9,12 @@
89
Point,
910
)
1011
from geojson_pydantic.types import Position2D
12+
import matplotlib.pyplot as plt
13+
import numpy as np
14+
import pandas as pd
1115
from pydantic import BaseModel
1216
from sqlalchemy import Row
17+
from sqlalchemy.orm.query import RowReturningQuery
1318

1419
from aross_stations_db.db.tables import Station
1520

@@ -64,6 +69,87 @@ def timeseries_query_results_to_json(
6469
]
6570

6671

72+
def timeseries_query_results_to_bar_plot_buffer(
73+
query: RowReturningQuery[tuple[dt.datetime, int]],
74+
start: dt.date,
75+
end: dt.date,
76+
format: str = 'png'
77+
) -> io.BytesIO:
78+
data = pd.read_sql(query.statement, query.session.connection())
79+
data.set_index('month', inplace=True)
80+
81+
start_str = start.strftime("%Y-%m")
82+
end_str = end.strftime("%Y-%m")
83+
84+
title_parts = [
85+
f"Monthly Rain-on-Snow Events",
86+
f"[{start_str} - {end_str}]"
87+
]
88+
title = "\n".join(title_parts)
89+
90+
if len(data) == 0:
91+
data.loc[start_str] = 0
92+
data.loc[end_str] = 0
93+
94+
95+
data.index = pd.to_datetime(data.index)
96+
data.index = data.index.strftime("%Y-%m")
97+
data = add_missing_plot_months(data)
98+
99+
plot = data.plot(
100+
kind="bar",
101+
title=title,
102+
ylabel="Event Count",
103+
xlabel="Month",
104+
rot=45,
105+
legend=False,
106+
)
107+
108+
ticks = plot.get_xticklines()
109+
labels = plot.get_xticklabels()
110+
111+
# Evenly space out the tick labels to avoid overcrowding
112+
indexes = np.linspace(0, len(labels)-1, num=12, dtype=int)
113+
for i, l in enumerate(labels):
114+
if i not in indexes:
115+
l.set_visible(False)
116+
else:
117+
l.set_horizontalalignment('right')
118+
l.set_rotation_mode('anchor')
119+
ticks[i*2].set_markersize(6)
120+
121+
plt.tight_layout()
122+
123+
# Create the buffer and return it so it can be sent to the requester
124+
buffer = io.BytesIO()
125+
plt.savefig(buffer, format="png")
126+
plt.close()
127+
128+
buffer.seek(0)
129+
130+
return buffer
131+
132+
133+
# If there are any months missing in the dataframe, add a "count 0" entry for them
134+
def add_missing_plot_months(df: pd.DataFrame) -> pd.DataFrame:
135+
start = df.index[0]
136+
end = df.index[-1]
137+
138+
syear, smonth = map(int, start.split('-'))
139+
eyear, emonth = map(int, end.split('-'))
140+
141+
while [syear, smonth] != [eyear, emonth]:
142+
smonth += 1
143+
if smonth > 12:
144+
smonth = 1
145+
syear += 1
146+
key = f"{syear:04}-{smonth:02}"
147+
if key not in df.index:
148+
df.loc[key] = [0]
149+
150+
return df.sort_index()
151+
152+
67153
class ClimatologyJsonElement(BaseModel):
68154
month: Annotated[int, Ge(1), Le(12)]
69155
event_count: int
Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
import datetime as dt
22
from typing import Annotated
33

4-
from fastapi import APIRouter, Depends, Query
4+
from fastapi import APIRouter, Depends, Form, Query
5+
from fastapi.responses import StreamingResponse
56
from geoalchemy2 import WKTElement
7+
import pandas as pd
68
from sqlalchemy.orm import Session
79

810
from aross_stations_db.api.dependencies import get_db_session
911
from aross_stations_db.api.v1.output import (
1012
TimeseriesJsonElement,
1113
timeseries_query_results_to_json,
14+
timeseries_query_results_to_bar_plot_buffer,
1215
)
1316
from aross_stations_db.db.query import timeseries_query
1417

15-
router = APIRouter()
18+
from loguru import logger
1619

20+
router = APIRouter()
1721

1822
@router.get("/monthly")
1923
def get_monthly_timeseries(
@@ -22,8 +26,58 @@ def get_monthly_timeseries(
2226
start: Annotated[dt.datetime, Query(description="ISO-format timestamp")],
2327
end: Annotated[dt.datetime, Query(description="ISO-format timestamp")],
2428
polygon: Annotated[str | None, WKTElement, Query(description="WKT shape")] = None,
29+
stations: Annotated[list[str], Query(description="List of station identifiers")] = [],
2530
) -> list[TimeseriesJsonElement]:
2631
"""Get a monthly timeseries of events matching query parameters."""
27-
query = timeseries_query(db=db, start=start, end=end, polygon=polygon)
32+
logger.debug(f"STATIONS: {stations}")
33+
query = timeseries_query(db=db, start=start, end=end, polygon=polygon, stations=stations)
2834

2935
return timeseries_query_results_to_json(query.all())
36+
37+
38+
@router.post("/monthly")
39+
def post_monthly_timeseries(
40+
db: Annotated[Session, Depends(get_db_session)],
41+
*,
42+
start: Annotated[dt.datetime, Form(description="ISO-format timestamp")],
43+
end: Annotated[dt.datetime, Form(description="ISO-format timestamp")],
44+
polygon: Annotated[str | None, WKTElement, Form(description="WKT shape")] = None,
45+
stations: Annotated[list[str], Form(description="List of station identifiers")] = [],
46+
) -> list[TimeseriesJsonElement]:
47+
"""Get a monthly timeseries of events matching query parameters."""
48+
logger.debug(f"STATIONS: {stations}")
49+
query = timeseries_query(db=db, start=start, end=end, polygon=polygon, stations=stations)
50+
51+
return timeseries_query_results_to_json(query.all())
52+
53+
54+
@router.get("/monthly/png")
55+
def get_monthly_timeseries_png(
56+
db: Annotated[Session, Depends(get_db_session)],
57+
*,
58+
start: Annotated[dt.datetime, Query(description="ISO-format timestamp")],
59+
end: Annotated[dt.datetime, Query(description="ISO-format timestamp")],
60+
polygon: Annotated[str | None, WKTElement, Query(description="WKT shape")] = None,
61+
stations: Annotated[list[str], Query(Description="List of station identifiers")] = [],
62+
) -> StreamingResponse:
63+
"""Get a monthly timeseries image plot of events matching query parameters."""
64+
query = timeseries_query(db=db, start=start, end=end, polygon=polygon, stations=stations)
65+
buffer = timeseries_query_results_to_bar_plot_buffer(query, start, end, 'png')
66+
67+
return StreamingResponse(buffer, media_type="image/png")
68+
69+
70+
@router.post("/monthly/png")
71+
def post_monthly_timeseries_png(
72+
db: Annotated[Session, Depends(get_db_session)],
73+
*,
74+
start: Annotated[dt.datetime, Form(description="ISO-format timestamp")],
75+
end: Annotated[dt.datetime, Form(description="ISO-format timestamp")],
76+
polygon: Annotated[str | None, WKTElement, Form(description="WKT shape")] = None,
77+
stations: Annotated[list[str], Form(Description="List of station identifiers")] = [],
78+
) -> StreamingResponse:
79+
"""Get a monthly timeseries image plot of events matching query parameters."""
80+
query = timeseries_query(db=db, start=start, end=end, polygon=polygon, stations=stations)
81+
buffer = timeseries_query_results_to_bar_plot_buffer(query, start, end, 'png')
82+
83+
return StreamingResponse(buffer, media_type="image/png")

src/aross_stations_db/db/query.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def station_data_query(
4646
end: dt.datetime | None,
4747
stations: list[str] = [],
4848
) -> RowReturningQuery[tuple[str, dt.datetime, dt.datetime, bool | None, int, int, int, int]]:
49-
print(type(stations))
5049
query = db.query(
5150
Event.station_id,
5251
Event.time_start,
@@ -72,6 +71,7 @@ def timeseries_query(
7271
start: dt.datetime,
7372
end: dt.datetime,
7473
polygon: str | None = None,
74+
stations: list[str] | None = None,
7575
) -> RowReturningQuery[tuple[dt.datetime, int]]:
7676
query = db.query(
7777
func.date_trunc("month", Event.time_start, type_=DateTime).label("month"),
@@ -87,6 +87,9 @@ def timeseries_query(
8787
)
8888
)
8989

90+
if stations:
91+
query = query.filter(Event.station_id.in_(stations))
92+
9093
return query.group_by("month").order_by("month")
9194

9295

0 commit comments

Comments
 (0)