Skip to content

Commit c92df50

Browse files
committed
Allow callable as config_loader_path
1 parent 31e9254 commit c92df50

File tree

2 files changed

+38
-29
lines changed

2 files changed

+38
-29
lines changed

djangosaml2/conf.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,55 +15,51 @@
1515

1616
import copy
1717
from importlib import import_module
18+
from typing import Callable, Optional, Union
1819

1920
from django.conf import settings
2021
from django.core.exceptions import ImproperlyConfigured
22+
from django.http import HttpRequest
23+
from django.utils.module_loading import import_string
2124
from saml2.config import SPConfig
2225

2326
from .utils import get_custom_setting
2427

2528

26-
def get_config_loader(path, request=None):
27-
i = path.rfind('.')
28-
module, attr = path[:i], path[i + 1:]
29+
def get_config_loader(path: str) -> Callable:
30+
""" Import the function at a given path and return it
31+
"""
2932
try:
30-
mod = import_module(module)
33+
config_loader = import_string(path)
3134
except ImportError as e:
32-
raise ImproperlyConfigured(
33-
'Error importing SAML config loader %s: "%s"' % (path, e))
34-
except ValueError as e:
35-
raise ImproperlyConfigured(
36-
'Error importing SAML config loader. Is SAML_CONFIG_LOADER '
37-
'a correctly string with a callable path?'
38-
)
39-
try:
40-
config_loader = getattr(mod, attr)
41-
except AttributeError:
42-
raise ImproperlyConfigured(
43-
'Module "%s" does not define a "%s" config loader' %
44-
(module, attr)
45-
)
35+
raise ImproperlyConfigured(f'Error importing SAML config loader {path}: "{e}"')
4636

47-
if not hasattr(config_loader, '__call__'):
48-
raise ImproperlyConfigured(
49-
"SAML config loader must be a callable object.")
37+
if not callable(config_loader):
38+
raise ImproperlyConfigured("SAML config loader must be a callable object.")
5039

5140
return config_loader
5241

5342

54-
def config_settings_loader(request=None):
55-
"""Utility function to load the pysaml2 configuration.
56-
57-
This is also the default config loader.
43+
def config_settings_loader(request: Optional[HttpRequest] = None) -> SPConfig:
44+
""" Utility function to load the pysaml2 configuration.
45+
The configuration can be modified based on the request being passed.
46+
This is the default config loader, which just loads the config from the settings.
5847
"""
5948
conf = SPConfig()
6049
conf.load(copy.deepcopy(settings.SAML_CONFIG))
6150
return conf
6251

6352

64-
def get_config(config_loader_path=None, request=None):
65-
config_loader_path = config_loader_path or get_custom_setting(
66-
'SAML_CONFIG_LOADER', 'djangosaml2.conf.config_settings_loader')
53+
def get_config(config_loader_path: Optional[Union[Callable, str]] = None, request: Optional[HttpRequest] = None) -> SPConfig:
54+
""" Load a config_loader function if necessary, and call that function with the request as argument.
55+
If the config_loader_path is a callable instead of a string, no importing is necessary and it will be used directly.
56+
Return the resulting SPConfig.
57+
"""
58+
config_loader_path = config_loader_path or get_custom_setting('SAML_CONFIG_LOADER', 'djangosaml2.conf.config_settings_loader')
59+
60+
if callable(config_loader_path):
61+
config_loader = config_loader_path
62+
else:
63+
config_loader = get_config_loader(config_loader_path)
6764

68-
config_loader = get_config_loader(config_loader_path)
6965
return config_loader(request)

djangosaml2/tests/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,12 @@ def test_config_loader(request):
636636
return config
637637

638638

639+
def test_config_loader_callable(request):
640+
config = SPConfig()
641+
config.load({'entityid': 'testentity_callable'})
642+
return config
643+
644+
639645
def test_config_loader_with_real_conf(request):
640646
config = SPConfig()
641647
config.load(conf.create_conf(sp_host='sp.example.com',
@@ -653,6 +659,13 @@ def test_custom_conf_loader(self):
653659

654660
self.assertEqual(conf.entityid, 'testentity')
655661

662+
def test_custom_conf_loader_callable(self):
663+
config_loader_path = test_config_loader_callable
664+
request = RequestFactory().get('/bar/foo')
665+
conf = get_config(config_loader_path, request)
666+
667+
self.assertEqual(conf.entityid, 'testentity_callable')
668+
656669
def test_custom_conf_loader_from_view(self):
657670
config_loader_path = 'djangosaml2.tests.test_config_loader_with_real_conf'
658671
request = RequestFactory().get('/login/')

0 commit comments

Comments
 (0)