Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions jupyter_server/extension/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from logging import Logger
from typing import TYPE_CHECKING, Any, cast

from jinja2 import Template
from jinja2.exceptions import TemplateNotFound

from jupyter_server.base.handlers import FileFindHandler
Expand All @@ -21,13 +22,14 @@ class ExtensionHandlerJinjaMixin:
template rendering.
"""

def get_template(self, name: str) -> str:
def get_template(self, name: str) -> Template:
"""Return the jinja template object for a given name"""
try:
env = f"{self.name}_jinja2_env" # type:ignore[attr-defined]
return cast(str, self.settings[env].get_template(name)) # type:ignore[attr-defined]
template = cast(Template, self.settings[env].get_template(name)) # type:ignore[attr-defined]
return template
except TemplateNotFound:
return cast(str, super().get_template(name)) # type:ignore[misc]
return cast(Template, super().get_template(name)) # type:ignore[misc]


class ExtensionHandlerMixin:
Expand Down Expand Up @@ -81,6 +83,20 @@ def server_config(self) -> Config:
def base_url(self) -> str:
return cast(str, self.settings.get("base_url", "/"))

def render_template(self, name: str, **ns) -> str:
"""Override render template to handle static_paths

If render_template is called with a template from the base environment
(e.g. default error pages)
make sure our extension-specific static_url is _not_ used.
"""
template = cast(Template, self.get_template(name)) # type:ignore[attr-defined]
ns.update(self.template_namespace) # type:ignore[attr-defined]
if template.environment is self.settings["jinja2_env"]:
# default template environment, use default static_url
ns["static_url"] = super().static_url # type:ignore[misc]
return cast(str, template.render(**ns))

@property
def static_url_prefix(self) -> str:
return self.extensionapp.static_url_prefix
Expand Down
6 changes: 5 additions & 1 deletion tests/extension/mockextensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
to load in various tests.
"""

from .app import MockExtensionApp
from .app import MockExtensionApp, MockExtensionNoTemplateApp


# Function that makes these extensions discoverable
Expand All @@ -13,6 +13,10 @@ def _jupyter_server_extension_points():
"module": "tests.extension.mockextensions.app",
"app": MockExtensionApp,
},
{
"module": "tests.extension.mockextensions.app",
"app": MockExtensionNoTemplateApp,
},
{"module": "tests.extension.mockextensions.mock1"},
{"module": "tests.extension.mockextensions.mock2"},
{"module": "tests.extension.mockextensions.mock3"},
Expand Down
27 changes: 26 additions & 1 deletion tests/extension/mockextensions/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from jupyter_events import EventLogger
from jupyter_events.schema_registry import SchemaRegistryException
from tornado import web
from traitlets import List, Unicode

from jupyter_server.base.handlers import JupyterHandler
Expand Down Expand Up @@ -44,14 +45,24 @@ def get(self):
self.write(self.render_template("index.html"))


class MockExtensionErrorHandler(ExtensionHandlerMixin, JupyterHandler):
def get(self):
raise web.HTTPError(418)


class MockExtensionApp(ExtensionAppJinjaMixin, ExtensionApp):
name = "mockextension"
template_paths: List[str] = List().tag(config=True) # type:ignore[assignment]
static_paths = [STATIC_PATH] # type:ignore[assignment]
mock_trait = Unicode("mock trait", config=True)
loaded = False

serverapp_config = {"jpserver_extensions": {"tests.extension.mockextensions.mock1": True}}
serverapp_config = {
"jpserver_extensions": {
"tests.extension.mockextensions.mock1": True,
"tests.extension.mockextensions.app.mockextension_notemplate": True,
}
}

@staticmethod
def get_extension_package():
Expand All @@ -69,6 +80,20 @@ def initialize_settings(self):
def initialize_handlers(self):
self.handlers.append(("/mock", MockExtensionHandler))
self.handlers.append(("/mock_template", MockExtensionTemplateHandler))
self.handlers.append(("/mock_error_template", MockExtensionErrorHandler))
self.loaded = True


class MockExtensionNoTemplateApp(ExtensionApp):
name = "mockextension_notemplate"
loaded = False

@staticmethod
def get_extension_package():
return "tests.extension.mockextensions"

def initialize_handlers(self):
self.handlers.append(("/mock_error_notemplate", MockExtensionErrorHandler))
self.loaded = True


Expand Down
4 changes: 3 additions & 1 deletion tests/extension/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,14 @@ async def _stop(*args):
"Shutting down 2 extensions",
"jupyter_server_terminals | extension app 'jupyter_server_terminals' stopping",
f"{extension_name} | extension app 'mockextension' stopping",
f"{extension_name} | extension app 'mockextension_notemplate' stopping",
"jupyter_server_terminals | extension app 'jupyter_server_terminals' stopped",
f"{extension_name} | extension app 'mockextension' stopped",
f"{extension_name} | extension app 'mockextension_notemplate' stopped",
}

# check the shutdown method was called twice
assert calls == 2
assert calls == 3


async def test_events(jp_serverapp, jp_fetch):
Expand Down
69 changes: 69 additions & 0 deletions tests/extension/test_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from html.parser import HTMLParser

import pytest
from tornado.httpclient import HTTPClientError


@pytest.fixture
Expand Down Expand Up @@ -118,3 +121,69 @@ async def test_base_url(jp_fetch, jp_server_config, jp_base_url):
assert r.code == 200
body = r.body.decode()
assert "mock static content" in body


class StylesheetFinder(HTMLParser):
"""Minimal HTML parser to find iframe.src attr"""

def __init__(self):
super().__init__()
self.stylesheets = []
self.body_chunks = []
self.in_head = False
self.in_body = False
self.in_script = False

def handle_starttag(self, tag, attrs):
tag = tag.lower()
if tag == "head":
self.in_head = True
elif tag == "body":
self.in_body = True
elif tag == "script":
self.in_script = True
elif self.in_head and tag.lower() == "link":
attr_dict = dict(attrs)
if attr_dict.get("rel", "").lower() == "stylesheet":
self.stylesheets.append(attr_dict["href"])

def handle_endtag(self, tag):
if tag == "head":
self.in_head = False
if tag == "body":
self.in_body = False
if tag == "script":
self.in_script = False

def handle_data(self, data):
if self.in_body and not self.in_script:
data = data.strip()
if data:
self.body_chunks.append(data)


def find_stylesheets_body(html):
"""Find the href= attr of stylesheets

and body text of an HTML document

stylesheets are used to test static_url prefix
"""
finder = StylesheetFinder()
finder.feed(html)
return (finder.stylesheets, "\n".join(finder.body_chunks))


@pytest.mark.parametrize("error_url", ["mock_error_template", "mock_error_notemplate"])
async def test_error_render(jp_fetch, jp_serverapp, jp_base_url, error_url):
with pytest.raises(HTTPClientError) as e:
await jp_fetch(error_url, method="GET")
r = e.value.response
assert r.code == 418
assert r.headers["Content-Type"] == "text/html"
html = r.body.decode("utf8")
stylesheets, body = find_stylesheets_body(html)
static_prefix = f"{jp_base_url}static/"
assert stylesheets
assert all(stylesheet.startswith(static_prefix) for stylesheet in stylesheets)
assert str(r.code) in body
Loading