Skip to content

Commit 26d9115

Browse files
committed
clean up
1 parent 3abb802 commit 26d9115

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

jupyter_server/services/events/handlers.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def on_close(self):
7171
self.event_logger.remove_listener(listener=self.event_listener)
7272

7373

74-
def validate_model(data: dict[str, Any], schema: jupyter_events.schema.EventSchema) -> None:
74+
def validate_model(
75+
data: dict[str, Any], registry: jupyter_events.schema_registry.SchemaRegistry
76+
) -> None:
7577
"""Validates for required fields in the JSON request body and verifies that
7678
a registered schema/version exists"""
7779
required_keys = {"schema_id", "version", "data"}
@@ -81,9 +83,7 @@ def validate_model(data: dict[str, Any], schema: jupyter_events.schema.EventSche
8183
raise Exception(message)
8284
schema_id = cast(str, data.get("schema_id"))
8385
version = cast(int, data.get("version"))
84-
if schema is None:
85-
message = f"Unregistered schema: `{schema_id}`"
86-
raise Exception(message)
86+
schema = registry.get(schema_id)
8787
if schema.version != version:
8888
message = f"Unregistered version: `{version}` for `{schema_id}`"
8989
raise Exception(message)
@@ -121,10 +121,9 @@ async def post(self):
121121
raise web.HTTPError(400, "No JSON data provided")
122122

123123
try:
124-
schema = self.event_logger.schemas.get(cast(str, payload.get("schema_id")))
125-
validate_model(payload, schema)
124+
validate_model(payload, self.event_logger.schemas)
126125
self.event_logger.emit(
127-
schema_id=schema.id,
126+
schema_id=cast(str, payload.get("schema_id")),
128127
data=cast("Dict[str, Any]", payload.get("data")),
129128
timestamp_override=get_timestamp(payload),
130129
)

tests/services/events/test_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,5 @@ async def test_post_event(jp_fetch, event_logger_sink, payload):
155155
async def test_post_event_400(jp_fetch, event_logger, payload):
156156
with pytest.raises(tornado.httpclient.HTTPClientError) as e:
157157
await jp_fetch("api", "events", method="POST", body=payload)
158+
158159
assert expected_http_error(e, 400)

0 commit comments

Comments
 (0)