Skip to content
/ kvmd Public

Commit 8fe7bf1

Browse files
committed
plugins.oauth.oauth2: add OAuth2 provider
1 parent 44753c2 commit 8fe7bf1

File tree

1 file changed

+194
-0
lines changed

1 file changed

+194
-0
lines changed

kvmd/plugins/oauth/oauth2.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# ========================================================================== #
2+
# #
3+
# KVMD - The main PiKVM daemon. #
4+
# #
5+
# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
6+
# 2024-2024 Markus Beckschulte (SLA/RWTH Aachen) #
7+
# This program is free software: you can redistribute it and/or modify #
8+
# it under the terms of the GNU General Public License as published by #
9+
# the Free Software Foundation, either version 3 of the License, or #
10+
# (at your option) any later version. #
11+
# #
12+
# This program is distributed in the hope that it will be useful, #
13+
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
14+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
15+
# GNU General Public License for more details. #
16+
# #
17+
# You should have received a copy of the GNU General Public License #
18+
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
19+
# #
20+
# ========================================================================== #
21+
22+
23+
from typing import Any
24+
from urllib.parse import urlencode
25+
import secrets
26+
import time
27+
28+
import aiohttp
29+
from aiohttp import ClientSession
30+
from yarl import URL
31+
32+
from ...validators.basic import valid_stripped_string_not_empty
33+
from ...validators.net import valid_url
34+
from ...yamlconf import Option
35+
from . import BaseOAuthProvider, OAuthError
36+
37+
38+
class Plugin(BaseOAuthProvider): # pylint: disable=too-many-instance-attributes
39+
def __init__( # pylint: disable=too-many-arguments
40+
self,
41+
client_id: str,
42+
client_secret: str,
43+
access_token_url: str,
44+
authorize_url: str,
45+
base_url: str,
46+
user_info_url: str,
47+
short_name: str,
48+
long_name: str,
49+
scope: str,
50+
username_attribute: str
51+
) -> None:
52+
super().__init__(long_name)
53+
self.__client_id = client_id
54+
self.__client_secret = client_secret
55+
self.__access_token_url: URL = URL(access_token_url)
56+
self.__authorize_url: URL = URL(authorize_url)
57+
self.__base_url: URL = URL(base_url)
58+
self.__user_info_url: URL = URL(user_info_url)
59+
self.__scope = scope
60+
self.__username_attribute = username_attribute
61+
self.__states: list[OAuthState] = []
62+
63+
@classmethod
64+
def get_plugin_options(cls) -> dict:
65+
return {
66+
"client_id": Option("", type=valid_stripped_string_not_empty),
67+
"client_secret": Option("", type=valid_stripped_string_not_empty),
68+
"access_token_url": Option("", type=valid_url),
69+
"authorize_url": Option("", type=valid_url),
70+
"base_url": Option("", type=valid_url),
71+
"user_info_url": Option("", type=valid_url),
72+
"short_name": Option("", type=valid_stripped_string_not_empty),
73+
"long_name": Option("", type=valid_stripped_string_not_empty),
74+
"scope": Option("", type=valid_stripped_string_not_empty),
75+
"username_attribute": Option("", type=valid_stripped_string_not_empty),
76+
}
77+
78+
def is_redirect_from_provider(self, request_query: dict) -> bool:
79+
return "code" in request_query # TODO
80+
81+
def get_authorize_url(self, redirect_url: URL, session: dict) -> str:
82+
"""
83+
Generates the Authorization-Code-Request
84+
@param redirect_url: the redirect URL the provider should redirect to after login
85+
@param session: the encrypted session
86+
@return: the authorization code request url
87+
"""
88+
params: dict[str, str] = {}
89+
params.update(
90+
{
91+
"client_id": self.__client_id,
92+
"response_type": "code",
93+
"scope": self.__scope,
94+
"access_type": "offline",
95+
"state": session['state'],
96+
"redirect_uri": redirect_url.human_repr(),
97+
}
98+
)
99+
ret = f"{self.__authorize_url}?{urlencode(params)}"
100+
return ret
101+
102+
def is_valid_session(self, oauth_session: dict) -> bool:
103+
"""
104+
Checks if the state provided in the oauth_session is valid.
105+
@param oauth_session: the session
106+
@return: True: session is valid
107+
"""
108+
if 'state' not in oauth_session:
109+
return False
110+
for stored_state in self.__states:
111+
if oauth_session['state'] == stored_state.get_value():
112+
if not stored_state.is_valid():
113+
self.__states.pop(oauth_session['state'])
114+
return False
115+
return True
116+
117+
async def get_user_info(
118+
self,
119+
oauth_session: dict,
120+
request_query: dict,
121+
redirect_url: URL
122+
) -> str:
123+
"""
124+
Returns the Username provided by the provider. Uses the authorization code to get an access_token.
125+
@param oauth_session: the session with state parameter
126+
@param request_query: the query as dict containing the authorization code
127+
@param redirect_url: the redirect_uri also used in get_authorize_url
128+
@return: Username
129+
"""
130+
if not self.is_valid_session(oauth_session):
131+
raise OAuthError("unknown or invalid state")
132+
133+
payload = {
134+
"grant_type": "authorization_code",
135+
"client_id": self.__client_id,
136+
"client_secret": self.__client_secret,
137+
"code": request_query['code'],
138+
"redirect_uri": str(redirect_url),
139+
"state": oauth_session['state']
140+
}
141+
headers = {"content-type": "application/x-www-form-urlencoded"}
142+
async with ClientSession() as session:
143+
try:
144+
async with session.post(self.__access_token_url, data=payload, headers=headers) as resp:
145+
token_data = await resp.json()
146+
if 'access_token' not in token_data:
147+
raise OAuthError(f"could not get access-token{str(token_data)}")
148+
access_token = token_data.get("access_token")
149+
except aiohttp.ClientConnectorError as error:
150+
raise OAuthError("could not connect to provider! error message: %s" % str(error))
151+
152+
headers = {
153+
"Cache-Control": "no-cache",
154+
"Authorization": f"Bearer {access_token}"
155+
}
156+
try:
157+
async with session.get(self.__user_info_url, headers=headers) as response:
158+
user_info = await response.json()
159+
return user_info.get(self.__username_attribute, "_oauth_user_")
160+
except aiohttp.ClientConnectorError as error:
161+
raise OAuthError("could not connect to provider! error message: %s" % str(error))
162+
163+
def register_new_session(self) -> dict:
164+
"""
165+
creates a new session with a new state
166+
@return: new session with state
167+
"""
168+
state = OAuthState()
169+
self.__states.append(state)
170+
return {'state': state.get_value()}
171+
172+
173+
class OAuthState:
174+
_TTL = 3600.0 # valid for one hour
175+
176+
def __init__(self) -> None:
177+
self.state = secrets.token_urlsafe(16)
178+
self.__created = time.time()
179+
180+
def __eq__(self, other: Any) -> bool:
181+
if isinstance(other, OAuthState):
182+
return self.state == other.state
183+
return False
184+
185+
def __getitem__(self, item: Any):
186+
if item == self.state:
187+
return self
188+
return None
189+
190+
def is_valid(self) -> bool:
191+
return (self.__created + self._TTL) > time.time()
192+
193+
def get_value(self) -> str:
194+
return self.state

0 commit comments

Comments
 (0)