|
1 | 1 | # Copyright (c) Jupyter Development Team.
|
2 | 2 | # Distributed under the terms of the Modified BSD License.
|
| 3 | +import importlib |
3 | 4 | import io
|
4 | 5 | import json
|
5 | 6 | import logging
|
|
14 | 15 | import pytest
|
15 | 16 | import tornado
|
16 | 17 | from tornado.escape import url_escape
|
17 |
| -from traitlets.config import Config |
| 18 | +from tornado.httpclient import HTTPClientError |
| 19 | +from tornado.websocket import WebSocketHandler |
| 20 | +from traitlets.config import Config, re |
18 | 21 |
|
| 22 | +from jupyter_server.auth import Authorizer |
19 | 23 | from jupyter_server.extension import serverextension
|
20 |
| -from jupyter_server.serverapp import ServerApp |
| 24 | +from jupyter_server.serverapp import JUPYTER_SERVICE_HANDLERS, ServerApp |
21 | 25 | from jupyter_server.services.contents.filemanager import FileContentsManager
|
22 | 26 | from jupyter_server.services.contents.largefilemanager import LargeFileManager
|
23 | 27 | from jupyter_server.utils import url_path_join
|
@@ -494,3 +498,122 @@ async def _():
|
494 | 498 | pass
|
495 | 499 |
|
496 | 500 | return _
|
| 501 | + |
| 502 | + |
| 503 | +@pytest.fixture |
| 504 | +def send_request(jp_fetch, jp_ws_fetch): |
| 505 | + """Send to Jupyter Server and return response code.""" |
| 506 | + |
| 507 | + async def _(url, **fetch_kwargs): |
| 508 | + if url.endswith("channels") or "/websocket/" in url: |
| 509 | + fetch = jp_ws_fetch |
| 510 | + else: |
| 511 | + fetch = jp_fetch |
| 512 | + |
| 513 | + try: |
| 514 | + r = await fetch(url, **fetch_kwargs, allow_nonstandard_methods=True) |
| 515 | + code = r.code |
| 516 | + except HTTPClientError as err: |
| 517 | + code = err.code |
| 518 | + else: |
| 519 | + if fetch is jp_ws_fetch: |
| 520 | + r.close() |
| 521 | + |
| 522 | + return code |
| 523 | + |
| 524 | + return _ |
| 525 | + |
| 526 | + |
| 527 | +@pytest.fixture |
| 528 | +def jp_server_auth_core_resources(): |
| 529 | + modules = [] |
| 530 | + for mod_name in JUPYTER_SERVICE_HANDLERS.values(): |
| 531 | + if mod_name: |
| 532 | + modules.extend(mod_name) |
| 533 | + resource_map = {} |
| 534 | + for handler_module in modules: |
| 535 | + mod = importlib.import_module(handler_module) |
| 536 | + name = mod.AUTH_RESOURCE |
| 537 | + for handler in mod.default_handlers: |
| 538 | + url_regex = handler[0] |
| 539 | + resource_map[url_regex] = name |
| 540 | + return resource_map |
| 541 | + |
| 542 | + |
| 543 | +@pytest.fixture |
| 544 | +def jp_server_auth_resources(jp_server_auth_core_resources): |
| 545 | + return jp_server_auth_core_resources |
| 546 | + |
| 547 | + |
| 548 | +@pytest.fixture |
| 549 | +def jp_server_authorizer(jp_server_auth_resources): |
| 550 | + class _(Authorizer): |
| 551 | + |
| 552 | + # Set these class attributes from within a test |
| 553 | + # to verify that they match the arguments passed |
| 554 | + # by the REST API. |
| 555 | + permissions: dict = {} |
| 556 | + |
| 557 | + HTTP_METHOD_TO_AUTH_ACTION = { |
| 558 | + "GET": "read", |
| 559 | + "HEAD": "read", |
| 560 | + "OPTIONS": "read", |
| 561 | + "POST": "write", |
| 562 | + "PUT": "write", |
| 563 | + "PATCH": "write", |
| 564 | + "DELETE": "write", |
| 565 | + "WEBSOCKET": "execute", |
| 566 | + } |
| 567 | + |
| 568 | + def match_url_to_resource(self, url, regex_mapping=None): |
| 569 | + """Finds the JupyterHandler regex pattern that would |
| 570 | + match the given URL and returns the resource name (str) |
| 571 | + of that handler. |
| 572 | +
|
| 573 | + e.g. |
| 574 | + /api/contents/... returns "contents" |
| 575 | + """ |
| 576 | + if not regex_mapping: |
| 577 | + regex_mapping = jp_server_auth_resources |
| 578 | + for regex, auth_resource in regex_mapping.items(): |
| 579 | + pattern = re.compile(regex) |
| 580 | + if pattern.fullmatch(url): |
| 581 | + return auth_resource |
| 582 | + |
| 583 | + def normalize_url(self, path): |
| 584 | + """Drop the base URL and make sure path leads with a /""" |
| 585 | + base_url = self.parent.base_url |
| 586 | + # Remove base_url |
| 587 | + if path.startswith(base_url): |
| 588 | + path = path[len(base_url) :] |
| 589 | + # Make sure path starts with / |
| 590 | + if not path.startswith("/"): |
| 591 | + path = "/" + path |
| 592 | + return path |
| 593 | + |
| 594 | + def is_authorized(self, handler, user, action, resource): |
| 595 | + # Parse Request |
| 596 | + if isinstance(handler, WebSocketHandler): |
| 597 | + method = "WEBSOCKET" |
| 598 | + else: |
| 599 | + method = handler.request.method |
| 600 | + url = self.normalize_url(handler.request.path) |
| 601 | + |
| 602 | + # Map request parts to expected action and resource. |
| 603 | + expected_action = self.HTTP_METHOD_TO_AUTH_ACTION[method] |
| 604 | + expected_resource = self.match_url_to_resource(url) |
| 605 | + |
| 606 | + # Assert that authorization layer returns the |
| 607 | + # correct action + resource. |
| 608 | + assert action == expected_action |
| 609 | + assert resource == expected_resource |
| 610 | + |
| 611 | + # Now, actually apply the authorization layer. |
| 612 | + return all( |
| 613 | + [ |
| 614 | + action in self.permissions.get("actions", []), |
| 615 | + resource in self.permissions.get("resources", []), |
| 616 | + ] |
| 617 | + ) |
| 618 | + |
| 619 | + return _ |
0 commit comments