|
1 | 1 | import json |
2 | 2 | import logging |
| 3 | +from urllib.parse import parse_qsl, urlencode, urlparse |
3 | 4 |
|
4 | 5 | from django.contrib.auth.mixins import LoginRequiredMixin |
| 6 | +from django.contrib.auth.views import redirect_to_login |
5 | 7 | from django.http import HttpResponse |
| 8 | +from django.shortcuts import resolve_url |
6 | 9 | from django.utils import timezone |
7 | 10 | from django.utils.decorators import method_decorator |
8 | 11 | from django.views.decorators.csrf import csrf_exempt |
@@ -144,6 +147,10 @@ def get(self, request, *args, **kwargs): |
144 | 147 | # Application is not available at this time. |
145 | 148 | return self.error_response(error, application=None) |
146 | 149 |
|
| 150 | + prompt = request.GET.get("prompt") |
| 151 | + if prompt == "login": |
| 152 | + return self.handle_prompt_login() |
| 153 | + |
147 | 154 | all_scopes = get_scopes_backend().get_all_scopes() |
148 | 155 | kwargs["scopes_descriptions"] = [all_scopes[scope] for scope in scopes] |
149 | 156 | kwargs["scopes"] = scopes |
@@ -211,6 +218,32 @@ def get(self, request, *args, **kwargs): |
211 | 218 |
|
212 | 219 | return self.render_to_response(self.get_context_data(**kwargs)) |
213 | 220 |
|
| 221 | + def handle_prompt_login(self): |
| 222 | + path = self.request.build_absolute_uri() |
| 223 | + resolved_login_url = resolve_url(self.get_login_url()) |
| 224 | + |
| 225 | + # If the login url is the same scheme and net location then use the |
| 226 | + # path as the "next" url. |
| 227 | + login_scheme, login_netloc = urlparse(resolved_login_url)[:2] |
| 228 | + current_scheme, current_netloc = urlparse(path)[:2] |
| 229 | + if (not login_scheme or login_scheme == current_scheme) and ( |
| 230 | + not login_netloc or login_netloc == current_netloc |
| 231 | + ): |
| 232 | + path = self.request.get_full_path() |
| 233 | + |
| 234 | + parsed = urlparse(path) |
| 235 | + |
| 236 | + parsed_query = dict(parse_qsl(parsed.query)) |
| 237 | + parsed_query.pop("prompt") |
| 238 | + |
| 239 | + parsed = parsed._replace(query=urlencode(parsed_query)) |
| 240 | + |
| 241 | + return redirect_to_login( |
| 242 | + parsed.geturl(), |
| 243 | + resolved_login_url, |
| 244 | + self.get_redirect_field_name(), |
| 245 | + ) |
| 246 | + |
214 | 247 |
|
215 | 248 | @method_decorator(csrf_exempt, name="dispatch") |
216 | 249 | class TokenView(OAuthLibMixin, View): |
|
0 commit comments