|
1 | 1 | import base64
|
| 2 | +import logging |
2 | 3 | import binascii
|
3 |
| -from typing import Tuple, Optional, Sequence |
| 4 | +from typing import cast, List, Tuple, Optional, Sequence |
| 5 | +from typing_extensions import TypedDict |
4 | 6 |
|
| 7 | +import httpx |
5 | 8 | from starlette.requests import HTTPConnection
|
6 | 9 | from starlette.responses import Response
|
| 10 | +from starlette.applications import Starlette |
7 | 11 | from starlette.authentication import (
|
8 | 12 | SimpleUser,
|
9 | 13 | AuthCredentials,
|
10 | 14 | AuthenticationError,
|
11 | 15 | AuthenticationBackend,
|
12 | 16 | )
|
13 | 17 |
|
| 18 | +logger = logging.getLogger(__name__) |
| 19 | + |
14 | 20 |
|
15 | 21 | class User(SimpleUser):
|
16 | 22 | def __init__(self, username: str, team: Optional[str]) -> None:
|
@@ -70,3 +76,81 @@ async def validate(self, username: str, password: str) -> ValidationResult:
|
70 | 76 | if not password:
|
71 | 77 | raise AuthenticationError("Must provide a password")
|
72 | 78 | return ['authenticated'], User(username, self.team)
|
| 79 | + |
| 80 | + |
| 81 | +NemesisUserInfo = TypedDict('NemesisUserInfo', { |
| 82 | + 'username': str, |
| 83 | + 'first_name': str, |
| 84 | + 'last_name': str, |
| 85 | + 'teams': List[str], |
| 86 | + 'is_blueshirt': bool, |
| 87 | + 'is_student': bool, |
| 88 | + 'is_team_leader': bool, |
| 89 | +}) |
| 90 | + |
| 91 | + |
| 92 | +class NemesisBackend(BasicAuthBackend): |
| 93 | + def __init__(self, _target: Optional[Starlette] = None, *, url: str) -> None: |
| 94 | + # Munge types to cope with httpx not supporting strict_optional but |
| 95 | + # actually being fine with given `None`. Note we expect only to pass |
| 96 | + # this value in tests, so need to cope with it being `None` most of the |
| 97 | + # time anyway. |
| 98 | + app = cast(Starlette, _target) |
| 99 | + self.client = httpx.AsyncClient(base_url=url, app=app) |
| 100 | + |
| 101 | + async def load_user(self, username: str, password: str) -> NemesisUserInfo: |
| 102 | + async with self.client as client: |
| 103 | + respone = await client.get( |
| 104 | + 'user/{}'.format(username), |
| 105 | + auth=(username, password), |
| 106 | + ) |
| 107 | + |
| 108 | + try: |
| 109 | + respone.raise_for_status() |
| 110 | + except httpx.HTTPError as e: |
| 111 | + if e.response.status_code != 403: |
| 112 | + logger.exception( |
| 113 | + "Failed to contact nemesis while trying to authenticate %r", |
| 114 | + username, |
| 115 | + ) |
| 116 | + raise AuthenticationError(e) from e |
| 117 | + |
| 118 | + return cast(NemesisUserInfo, respone.json()) |
| 119 | + |
| 120 | + def strip_team(self, team: str) -> str: |
| 121 | + # All teams from nemesis *should* start with this prefix... |
| 122 | + if team.startswith('team-'): |
| 123 | + return team[len('team-'):] |
| 124 | + return team |
| 125 | + |
| 126 | + def get_team(self, info: NemesisUserInfo) -> Optional[str]: |
| 127 | + teams = [self.strip_team(x) for x in info['teams']] |
| 128 | + |
| 129 | + if not teams: |
| 130 | + if info['is_student']: |
| 131 | + logger.warning("Competitor %r has no teams!", info['username']) |
| 132 | + return None |
| 133 | + |
| 134 | + team = teams[0] |
| 135 | + |
| 136 | + if len(teams) > 1: |
| 137 | + logger.warning( |
| 138 | + "User %r is in more than one team (%r), using %r", |
| 139 | + info['username'], |
| 140 | + teams, |
| 141 | + team, |
| 142 | + ) |
| 143 | + |
| 144 | + return team |
| 145 | + |
| 146 | + async def validate(self, username: str, password: str) -> ValidationResult: |
| 147 | + if not username: |
| 148 | + raise AuthenticationError("Must provide a username") |
| 149 | + if not password: |
| 150 | + raise AuthenticationError("Must provide a password") |
| 151 | + |
| 152 | + info = await self.load_user(username, password) |
| 153 | + |
| 154 | + team = self.get_team(info) |
| 155 | + |
| 156 | + return ['authenticated'], User(username, team) |
0 commit comments