|
1 | 1 | import datetime |
2 | 2 | import logging |
3 | | -import uuid |
4 | 3 | from sqlalchemy import create_engine |
5 | 4 |
|
6 | 5 | from sqlalchemy.orm import sessionmaker |
| 6 | +from cryptography.fernet import Fernet |
| 7 | +import base64 |
7 | 8 |
|
8 | | -from src.app.db.models import Auth, Base, Athlete |
9 | | -from stravalib.model import SummaryAthlete |
| 9 | +from src.app.db.models import Auth, Base, User |
10 | 10 |
|
11 | 11 | logging.basicConfig(level=logging.INFO) |
12 | 12 | logger = logging.getLogger(__name__) |
13 | 13 |
|
14 | 14 |
|
| 15 | +def encrypt_token(token: str, key: bytes) -> str: |
| 16 | + """Encrypts an access token.""" |
| 17 | + f = Fernet(key) |
| 18 | + encrypted_token = f.encrypt(token.encode()) |
| 19 | + return base64.urlsafe_b64encode(encrypted_token).decode() |
| 20 | + |
| 21 | + |
| 22 | +def decrypt_token(encrypted_token: str, key: bytes) -> str: |
| 23 | + """Decrypts an access token.""" |
| 24 | + f = Fernet(key) |
| 25 | + encrypted_token_bytes = base64.urlsafe_b64decode(encrypted_token) |
| 26 | + decrypted_token = f.decrypt(encrypted_token_bytes).decode() |
| 27 | + return decrypted_token |
| 28 | + |
| 29 | + |
15 | 30 | class Database: |
16 | | - def __init__(self, connection_string: str): |
| 31 | + def __init__(self, connection_string: str, encryption_key: bytes): |
17 | 32 | engine = create_engine(connection_string) |
18 | 33 | Base.metadata.create_all(engine) # create the tables. |
19 | | - Session = sessionmaker(bind=engine) |
20 | | - self.session = Session() |
21 | | - |
22 | | - def add_athlete(self, athlete: SummaryAthlete): |
23 | | - found_athlete = ( |
24 | | - self.session.query(Athlete).filter(Athlete.athlete_id == athlete.id).first() |
25 | | - ) |
26 | | - if not found_athlete: |
27 | | - id = str(uuid.uuid4()) |
28 | | - kwargs = { |
29 | | - k: v |
30 | | - for k, v in athlete.__dict__.items() |
31 | | - if k in Athlete.__table__.columns.keys() |
32 | | - } |
33 | | - |
34 | | - kwargs["created_at"] = datetime.datetime.now() |
35 | | - kwargs["updated_at"] = datetime.datetime.now() |
36 | | - kwargs["athlete_id"] = athlete.id |
37 | | - |
38 | | - kwargs["uuid"] = id |
39 | | - |
40 | | - athlete_model = Athlete(**kwargs) |
41 | | - self.session.add(athlete_model) |
42 | | - self.session.commit() |
43 | | - logger.info(f"Added athlete {athlete.id} to the database") |
44 | | - |
45 | | - return id |
46 | | - else: |
47 | | - logger.info(f"Athlete with id {athlete.id} already exists in the database") |
48 | | - return found_athlete.uuid |
49 | | - |
50 | | - def get_athlete(self, uuid: str) -> Athlete: |
51 | | - return self.session.query(Athlete).filter(Athlete.uuid == uuid).first() |
52 | | - |
53 | | - def add_auth( |
54 | | - self, |
55 | | - athlete_id: int, |
56 | | - access_token: str, |
57 | | - refresh_token: str, |
58 | | - expires_at: int, |
59 | | - scope: str, |
60 | | - ): |
61 | | - if self.session.query(Auth).filter(Auth.athlete_id == athlete_id).first(): |
62 | | - self.session.query(Auth).filter(Auth.athlete_id == athlete_id).update( |
63 | | - { |
64 | | - "access_token": access_token, |
65 | | - "refresh_token": refresh_token, |
66 | | - "expires_at": expires_at, |
67 | | - "scope": scope, |
68 | | - "updated_at": datetime.datetime.now(), |
69 | | - } |
| 34 | + self.Session = sessionmaker(bind=engine) |
| 35 | + |
| 36 | + self.encryption_key = encryption_key |
| 37 | + |
| 38 | + def add_user(self, user: User): |
| 39 | + with self.Session() as session: |
| 40 | + existing_user = ( |
| 41 | + session.query(User).filter(User.athlete_id == user.athlete_id).first() |
70 | 42 | ) |
71 | | - self.session.commit() |
72 | | - logger.info(f"Updated auth for athlete {athlete_id}") |
73 | | - |
74 | | - else: |
75 | | - auth = Auth( |
76 | | - uuid=str(uuid.uuid4()), |
77 | | - athlete_id=athlete_id, |
78 | | - access_token=access_token, |
79 | | - refresh_token=refresh_token, |
80 | | - expires_at=expires_at, |
81 | | - scope=scope, |
82 | | - created_at=datetime.datetime.now(), |
83 | | - updated_at=datetime.datetime.now(), |
| 43 | + if not existing_user: |
| 44 | + session.add(user) |
| 45 | + session.commit() |
| 46 | + logger.info(f"Added user {user.uuid} to the database") |
| 47 | + |
| 48 | + return str(user.uuid) |
| 49 | + |
| 50 | + logger.info( |
| 51 | + f"User with id {existing_user.uuid} already exists in the database" |
84 | 52 | ) |
85 | | - self.session.add(auth) |
86 | | - self.session.commit() |
87 | | - logger.info(f"Added auth for athlete {athlete_id}") |
| 53 | + return str(existing_user.uuid) |
| 54 | + |
| 55 | + def get_user(self, uuid: str) -> User: |
| 56 | + with self.Session() as session: |
| 57 | + return session.query(User).filter(User.uuid == uuid).first() |
88 | 58 |
|
89 | | - def get_auth(self, athlete_id: int) -> Auth: |
90 | | - return self.session.query(Auth).filter(Auth.athlete_id == athlete_id).first() |
| 59 | + def add_auth(self, auth: Auth): |
| 60 | + # encrypt tokens |
| 61 | + auth.access_token = encrypt_token(auth.access_token, self.encryption_key) |
| 62 | + auth.refresh_token = encrypt_token(auth.refresh_token, self.encryption_key) |
| 63 | + |
| 64 | + with self.Session() as session: |
| 65 | + existing_auth = ( |
| 66 | + session.query(Auth).filter(Auth.athlete_id == auth.athlete_id).first() |
| 67 | + ) |
| 68 | + if existing_auth: |
| 69 | + existing_auth.access_token = auth.access_token |
| 70 | + existing_auth.refresh_token = auth.refresh_token |
| 71 | + existing_auth.expires_at = auth.expires_at |
| 72 | + existing_auth.scope = auth.scope |
| 73 | + existing_auth.updated_at = datetime.datetime.now() |
| 74 | + |
| 75 | + session.commit() |
| 76 | + logger.info(f"Updated auth for athlete {auth.athlete_id}") |
| 77 | + |
| 78 | + else: |
| 79 | + session.add(auth) |
| 80 | + session.commit() |
| 81 | + logger.info(f"Added auth for athlete {auth.athlete_id}") |
| 82 | + |
| 83 | + def get_auth_by_athlete_id(self, athlete_id: int) -> Auth: |
| 84 | + with self.Session() as session: |
| 85 | + auth = session.query(Auth).filter(Auth.athlete_id == athlete_id).first() |
| 86 | + auth.access_token = decrypt_token(auth.access_token, self.encryption_key) |
| 87 | + auth.refresh_token = decrypt_token(auth.refresh_token, self.encryption_key) |
| 88 | + return auth |
91 | 89 |
|
92 | 90 | # def add_activity(self, activity: SummaryActivity): |
93 | 91 | # activity_dict = activity.model_dump() |
|
0 commit comments