Skip to content

Commit 0c1f1a5

Browse files
authored
Add EndpointPayload ABC and Specific Payload Dataclasses for API Requests (#208)
This PR adds 'payload' dataclasses for the API endpoints. The EndpointPayload class is an Abstract Base Class, which defines the structure for all other payload classes. Payload classes are in charge of normalizing and validating the payload before it is used for a POST request. They also implement a .to_dict property for easy access to the correctly structured payload data as a python dictionary.
1 parent 8a4b2a1 commit 0c1f1a5

File tree

4 files changed

+362
-6
lines changed

4 files changed

+362
-6
lines changed
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
from abc import ABC
2+
from abc import abstractmethod
3+
from dataclasses import dataclass
4+
from dataclasses import field
5+
from enum import Enum
6+
from typing import Optional
7+
8+
from datacommons_client.utils.error_handling import (
9+
InvalidObservationSelectError,
10+
)
11+
12+
13+
@dataclass
14+
class EndpointRequestPayload(ABC):
15+
"""
16+
Abstract base class for payload dataclasses.
17+
Defines the required interface for all payload dataclasses.
18+
"""
19+
20+
@abstractmethod
21+
def normalize(self) -> None:
22+
"""Normalize the payload for consistent internal representation."""
23+
pass
24+
25+
@abstractmethod
26+
def validate(self) -> None:
27+
"""Validate the payload to ensure its structure and contents are correct."""
28+
pass
29+
30+
@property
31+
@abstractmethod
32+
def to_dict(self) -> dict:
33+
"""Convert the payload into a dictionary format for API requests."""
34+
pass
35+
36+
37+
@dataclass
38+
class NodeRequestPayload(EndpointRequestPayload):
39+
"""
40+
A dataclass to structure, normalize, and validate the payload for a Node V2 API request.
41+
42+
Attributes:
43+
node_dcids (str | list[str]): The DCID(s) of the nodes to query.
44+
expression (str): The property or relation expression(s) to query.
45+
"""
46+
47+
node_dcids: str | list[str]
48+
expression: str
49+
50+
def __post_init__(self):
51+
self.normalize()
52+
self.validate()
53+
54+
def normalize(self):
55+
if isinstance(self.node_dcids, str):
56+
self.node_dcids = [self.node_dcids]
57+
58+
def validate(self):
59+
if not isinstance(self.expression, str):
60+
raise ValueError("Expression must be a string.")
61+
62+
@property
63+
def to_dict(self) -> dict:
64+
return {"nodes": self.node_dcids, "property": self.expression}
65+
66+
67+
class ObservationSelect(str, Enum):
68+
DATE = "date"
69+
VARIABLE = "variable"
70+
ENTITY = "entity"
71+
VALUE = "value"
72+
73+
@classmethod
74+
def _missing_(cls, value):
75+
"""Handle missing enum values by raising a custom error."""
76+
valid_values = [member.value for member in cls]
77+
message = f"Invalid `select` field: '{value}'. Only {', '.join(valid_values)} are allowed."
78+
raise InvalidObservationSelectError(message=message)
79+
80+
81+
class ObservationDate(str, Enum):
82+
LATEST = "LATEST"
83+
ALL = ""
84+
85+
86+
@dataclass
87+
class ObservationRequestPayload(EndpointRequestPayload):
88+
"""
89+
A dataclass to structure, normalize, and validate the payload for an Observation V2 API request.
90+
91+
Attributes:
92+
date (str): The date for which data is being requested.
93+
variable_dcids (str | list[str]): One or more variable IDs for the data.
94+
select (list[ObservationSelect]): Fields to include in the response.
95+
Defaults to ["date", "variable", "entity", "value"].
96+
entity_dcids (Optional[str | list[str]]): One or more entity IDs to filter the data.
97+
entity_expression (Optional[str]): A string expression to filter entities.
98+
"""
99+
100+
date: ObservationDate | str = ""
101+
variable_dcids: str | list[str] = field(default_factory=list)
102+
select: list[ObservationSelect | str] = field(
103+
default_factory=lambda: [
104+
ObservationSelect.DATE,
105+
ObservationSelect.VARIABLE,
106+
ObservationSelect.ENTITY,
107+
ObservationSelect.VALUE,
108+
]
109+
)
110+
entity_dcids: Optional[str | list[str]] = None
111+
entity_expression: Optional[str] = None
112+
113+
def __post_init__(self):
114+
"""
115+
Initializes the payload, performing validation and normalization.
116+
117+
Raises:
118+
ValueError: If validation rules are violated.
119+
"""
120+
self.RequiredSelect = {"variable", "entity"}
121+
self.normalize()
122+
self.validate()
123+
124+
def normalize(self):
125+
"""
126+
Normalizes the payload for consistent internal representation.
127+
128+
- Converts `variable_dcids` and `entity_dcids` to lists if they are passed as strings.
129+
- Normalizes the `date` field to ensure it is in the correct format.
130+
"""
131+
# Normalize variable
132+
if isinstance(self.variable_dcids, str):
133+
self.variable_dcids = [self.variable_dcids]
134+
135+
# Normalize entity
136+
if isinstance(self.entity_dcids, str):
137+
self.entity_dcids = [self.entity_dcids]
138+
139+
# Normalize date field
140+
if self.date.upper() == "ALL":
141+
self.date = ObservationDate.ALL
142+
elif (self.date.upper() == "LATEST") or (self.date == ""):
143+
self.date = ObservationDate.LATEST
144+
145+
def validate(self):
146+
"""
147+
Validates the payload to ensure consistency and correctness.
148+
149+
Raises:
150+
ValueError: If both `entity_dcids` and `entity_expression` are set,
151+
if neither is set, or if required fields are missing from `select`.
152+
"""
153+
154+
# Validate mutually exclusive entity fields
155+
if bool(self.entity_dcids) == bool(self.entity_expression):
156+
raise ValueError(
157+
"Exactly one of 'entity_dcids' or 'entity_expression' must be set."
158+
)
159+
160+
# Check if all required fields are present
161+
missing_fields = self.RequiredSelect - set(self.select)
162+
if missing_fields:
163+
raise ValueError(
164+
f"The 'select' field must include at least the following: {', '.join(self.RequiredSelect)} "
165+
f"(missing: {', '.join(missing_fields)})"
166+
)
167+
168+
# Check all select fields are valid
169+
[ObservationSelect(select_field) for select_field in self.select]
170+
171+
@property
172+
def to_dict(self) -> dict:
173+
"""
174+
Converts the payload into a dictionary format for API requests.
175+
176+
Returns:
177+
dict: The normalized and validated payload.
178+
"""
179+
return {
180+
"date": self.date,
181+
"variable": {"dcids": self.variable_dcids},
182+
"entity": (
183+
{"dcids": self.entity_dcids}
184+
if self.entity_dcids
185+
else {"expression": self.entity_expression}
186+
),
187+
"select": self.select,
188+
}
189+
190+
191+
@dataclass
192+
class ResolveRequestPayload(EndpointRequestPayload):
193+
"""
194+
A dataclass to structure, normalize, and validate the payload for a Resolve V2 API request.
195+
196+
Attributes:
197+
node_dcids (str | list[str]): The DCID(s) of the nodes to query.
198+
expression (str): The relation expression to query.
199+
"""
200+
201+
node_dcids: str | list[str]
202+
expression: str
203+
204+
def __post_init__(self):
205+
self.normalize()
206+
self.validate()
207+
208+
def normalize(self):
209+
if isinstance(self.node_dcids, str):
210+
self.node_dcids = [self.node_dcids]
211+
212+
def validate(self):
213+
if not isinstance(self.expression, str):
214+
raise ValueError("Expression must be a string.")
215+
216+
@property
217+
def to_dict(self) -> dict:
218+
return {"nodes": self.node_dcids, "property": self.expression}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from datacommons_client.endpoints.payloads import NodeRequestPayload
2+
from datacommons_client.endpoints.payloads import ObservationDate
3+
from datacommons_client.endpoints.payloads import ObservationRequestPayload
4+
from datacommons_client.endpoints.payloads import ObservationSelect
5+
from datacommons_client.endpoints.payloads import ResolveRequestPayload
6+
from datacommons_client.utils.error_handling import (
7+
InvalidObservationSelectError,
8+
)
9+
import pytest
10+
11+
12+
def test_node_payload_normalize():
13+
"""Tests that NodeRequestPayload correctly normalizes single and multiple node_dcids."""
14+
payload = NodeRequestPayload(node_dcids="node1", expression="prop1")
15+
assert payload.node_dcids == ["node1"]
16+
17+
payload = NodeRequestPayload(node_dcids=["node1", "node2"], expression="prop1")
18+
assert payload.node_dcids == ["node1", "node2"]
19+
20+
21+
def test_node_payload_validate():
22+
"""Tests that NodeRequestPayload validates its inputs correctly."""
23+
with pytest.raises(ValueError):
24+
NodeRequestPayload(
25+
node_dcids="node1", expression=123
26+
) # `expression` must be a string
27+
28+
29+
def test_node_payload_to_dict():
30+
"""Tests NodeRequestPayload conversion to dictionary."""
31+
payload = NodeRequestPayload(node_dcids="node1", expression="prop1")
32+
assert payload.to_dict == {"nodes": ["node1"], "property": "prop1"}
33+
34+
35+
def test_observation_payload_normalize():
36+
"""Tests that ObservationRequestPayload normalizes inputs correctly."""
37+
payload = ObservationRequestPayload(
38+
date="LATEST",
39+
variable_dcids="var1",
40+
select=["variable", "entity"],
41+
entity_dcids="ent1",
42+
)
43+
assert payload.variable_dcids == ["var1"]
44+
assert payload.entity_dcids == ["ent1"]
45+
assert payload.date == ObservationDate.LATEST
46+
47+
payload = ObservationRequestPayload(
48+
date="all",
49+
variable_dcids=["var1"],
50+
select=["variable", "entity"],
51+
entity_dcids=["ent1"],
52+
)
53+
assert payload.date == ObservationDate.ALL
54+
assert payload.variable_dcids == ["var1"]
55+
assert payload.entity_dcids == ["ent1"]
56+
57+
58+
def test_observation_select_invalid_value():
59+
"""Tests that an invalid ObservationSelect value raises InvalidObservationSelectError."""
60+
with pytest.raises(
61+
InvalidObservationSelectError,
62+
match=r"Invalid `select` field: 'invalid'. Only date, variable, entity, value are allowed.",
63+
):
64+
ObservationSelect("invalid")
65+
66+
67+
def test_observation_payload_validate():
68+
"""Tests that ObservationRequestPayload validates its inputs."""
69+
with pytest.raises(ValueError):
70+
ObservationRequestPayload(
71+
date="LATEST",
72+
variable_dcids="var1",
73+
select=["variable"],
74+
entity_dcids=None,
75+
entity_expression=None,
76+
) # Requires either `entity_dcids` or `entity_expression`
77+
78+
with pytest.raises(ValueError):
79+
ObservationRequestPayload(
80+
date="LATEST",
81+
variable_dcids="var1",
82+
select=["value"], # Missing required "variable" and "entity"
83+
entity_expression="expression",
84+
)
85+
86+
with pytest.raises(ValueError):
87+
ObservationRequestPayload(
88+
date="LATEST",
89+
variable_dcids="var1",
90+
select=["variable", "entity"],
91+
entity_dcids="ent1",
92+
entity_expression="expression", # Both `entity_dcids` and `entity_expression` set
93+
)
94+
95+
96+
def test_observation_payload_to_dict():
97+
"""Tests ObservationRequestPayload conversion to dictionary."""
98+
payload = ObservationRequestPayload(
99+
date="LATEST",
100+
variable_dcids="var1",
101+
select=["variable", "entity"],
102+
entity_dcids="ent1",
103+
)
104+
assert payload.to_dict == {
105+
"date": ObservationDate.LATEST,
106+
"variable": {"dcids": ["var1"]},
107+
"entity": {"dcids": ["ent1"]},
108+
"select": ["variable", "entity"],
109+
}
110+
111+
112+
def test_resolve_payload_normalize():
113+
"""Tests that ResolveRequestPayload normalizes single and multiple node_dcids."""
114+
payload = ResolveRequestPayload(node_dcids="node1", expression="expr1")
115+
assert payload.node_dcids == ["node1"]
116+
117+
payload = ResolveRequestPayload(
118+
node_dcids=["node1", "node2"], expression="expr1"
119+
)
120+
assert payload.node_dcids == ["node1", "node2"]
121+
122+
123+
def test_resolve_payload_validate():
124+
"""Tests that ResolveRequestPayload validates its inputs correctly."""
125+
with pytest.raises(ValueError):
126+
ResolveRequestPayload(
127+
node_dcids="node1", expression=123
128+
) # `expression` must be a string
129+
130+
131+
def test_resolve_payload_to_dict():
132+
"""Tests ResolveRequestPayload conversion to dictionary."""
133+
payload = ResolveRequestPayload(node_dcids="node1", expression="expr1")
134+
assert payload.to_dict == {"nodes": ["node1"], "property": "expr1"}

datacommons_client/tests/endpoints/test_request_handling.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
import pytest
55
import requests
66

7-
from datacommons_client.utils.error_hanlding import (
8-
DCAuthenticationError,
9-
DCConnectionError,
10-
DCStatusError,
11-
APIError,
12-
)
7+
from datacommons_client.utils.error_hanlding import APIError
8+
from datacommons_client.utils.error_hanlding import DCAuthenticationError
9+
from datacommons_client.utils.error_hanlding import DCConnectionError
10+
from datacommons_client.utils.error_hanlding import DCStatusError
1311
from datacommons_client.utils.error_hanlding import InvalidDCInstanceError
1412
from datacommons_client.utils.request_handling import _check_instance_is_valid
1513
from datacommons_client.utils.request_handling import _fetch_with_pagination

datacommons_client/utils/error_handling.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,9 @@ class InvalidDCInstanceError(DataCommonsError):
7676
"""Raised when an invalid Data Commons instance is provided."""
7777

7878
default_message = "The specified Data Commons instance is invalid."
79+
80+
81+
class InvalidObservationSelectError(DataCommonsError):
82+
"""Raised when an invalid ObservationSelect field is provided."""
83+
84+
default_message = "The ObservationSelect field is invalid."

0 commit comments

Comments
 (0)