|
15 | 15 |
|
16 | 16 | import copy
|
17 | 17 | from importlib import import_module
|
| 18 | +from typing import Callable, Optional, Union |
18 | 19 |
|
19 | 20 | from django.conf import settings
|
20 | 21 | from django.core.exceptions import ImproperlyConfigured
|
| 22 | +from django.http import HttpRequest |
| 23 | +from django.utils.module_loading import import_string |
21 | 24 | from saml2.config import SPConfig
|
22 | 25 |
|
23 | 26 | from .utils import get_custom_setting
|
24 | 27 |
|
25 | 28 |
|
26 |
| -def get_config_loader(path, request=None): |
27 |
| - i = path.rfind('.') |
28 |
| - module, attr = path[:i], path[i + 1:] |
| 29 | +def get_config_loader(path: str) -> Callable: |
| 30 | + """ Import the function at a given path and return it |
| 31 | + """ |
29 | 32 | try:
|
30 |
| - mod = import_module(module) |
| 33 | + config_loader = import_string(path) |
31 | 34 | except ImportError as e:
|
32 |
| - raise ImproperlyConfigured( |
33 |
| - 'Error importing SAML config loader %s: "%s"' % (path, e)) |
34 |
| - except ValueError as e: |
35 |
| - raise ImproperlyConfigured( |
36 |
| - 'Error importing SAML config loader. Is SAML_CONFIG_LOADER ' |
37 |
| - 'a correctly string with a callable path?' |
38 |
| - ) |
39 |
| - try: |
40 |
| - config_loader = getattr(mod, attr) |
41 |
| - except AttributeError: |
42 |
| - raise ImproperlyConfigured( |
43 |
| - 'Module "%s" does not define a "%s" config loader' % |
44 |
| - (module, attr) |
45 |
| - ) |
| 35 | + raise ImproperlyConfigured(f'Error importing SAML config loader {path}: "{e}"') |
46 | 36 |
|
47 |
| - if not hasattr(config_loader, '__call__'): |
48 |
| - raise ImproperlyConfigured( |
49 |
| - "SAML config loader must be a callable object.") |
| 37 | + if not callable(config_loader): |
| 38 | + raise ImproperlyConfigured("SAML config loader must be a callable object.") |
50 | 39 |
|
51 | 40 | return config_loader
|
52 | 41 |
|
53 | 42 |
|
54 |
| -def config_settings_loader(request=None): |
55 |
| - """Utility function to load the pysaml2 configuration. |
56 |
| -
|
57 |
| - This is also the default config loader. |
| 43 | +def config_settings_loader(request: Optional[HttpRequest] = None) -> SPConfig: |
| 44 | + """ Utility function to load the pysaml2 configuration. |
| 45 | + The configuration can be modified based on the request being passed. |
| 46 | + This is the default config loader, which just loads the config from the settings. |
58 | 47 | """
|
59 | 48 | conf = SPConfig()
|
60 | 49 | conf.load(copy.deepcopy(settings.SAML_CONFIG))
|
61 | 50 | return conf
|
62 | 51 |
|
63 | 52 |
|
64 |
| -def get_config(config_loader_path=None, request=None): |
65 |
| - config_loader_path = config_loader_path or get_custom_setting( |
66 |
| - 'SAML_CONFIG_LOADER', 'djangosaml2.conf.config_settings_loader') |
| 53 | +def get_config(config_loader_path: Optional[Union[Callable, str]] = None, request: Optional[HttpRequest] = None) -> SPConfig: |
| 54 | + """ Load a config_loader function if necessary, and call that function with the request as argument. |
| 55 | + If the config_loader_path is a callable instead of a string, no importing is necessary and it will be used directly. |
| 56 | + Return the resulting SPConfig. |
| 57 | + """ |
| 58 | + config_loader_path = config_loader_path or get_custom_setting('SAML_CONFIG_LOADER', 'djangosaml2.conf.config_settings_loader') |
| 59 | + |
| 60 | + if callable(config_loader_path): |
| 61 | + config_loader = config_loader_path |
| 62 | + else: |
| 63 | + config_loader = get_config_loader(config_loader_path) |
67 | 64 |
|
68 |
| - config_loader = get_config_loader(config_loader_path) |
69 | 65 | return config_loader(request)
|
0 commit comments