Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 100 additions & 112 deletions lib/crewai/src/crewai/cli/shared/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,119 +3,72 @@
import os
from pathlib import Path
import sys
from typing import BinaryIO, cast
import tempfile
from typing import Final, Literal, cast

from cryptography.fernet import Fernet


if sys.platform == "win32":
import msvcrt
else:
import fcntl
_FERNET_KEY_LENGTH: Final[Literal[44]] = 44


class TokenManager:
"""Manages encrypted token storage."""

def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
"""Initialize the TokenManager.

:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
Args:
file_path: The file path to store encrypted tokens.
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)

@staticmethod
def _acquire_lock(file_handle: BinaryIO) -> None:
"""
Acquire an exclusive lock on a file handle.

Args:
file_handle: Open file handle to lock.
"""
if sys.platform == "win32":
msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
else:
fcntl.flock(file_handle.fileno(), fcntl.LOCK_EX)

@staticmethod
def _release_lock(file_handle: BinaryIO) -> None:
"""
Release the lock on a file handle.

Args:
file_handle: Open file handle to unlock.
"""
if sys.platform == "win32":
msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
else:
fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)

def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key with file locking to prevent race conditions.
"""Get or create the encryption key.

Returns:
The encryption key.
The encryption key as bytes.
"""
key_filename = "secret.key"
storage_path = self.get_secure_storage_path()
key_filename: str = "secret.key"

key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key

lock_file_path = storage_path / f"{key_filename}.lock"

try:
lock_file_path.touch()

with open(lock_file_path, "r+b") as lock_file:
self._acquire_lock(lock_file)
try:
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
return key

new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
finally:
try:
self._release_lock(lock_file)
except OSError:
pass
except OSError:
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
return key

new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
new_key = Fernet.generate_key()
if self._atomic_create_secure_file(key_filename, new_key):
return new_key

key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key

raise RuntimeError("Failed to create or read encryption key")

def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
Save the access token and its expiration time.
"""Save the access token and its expiration time.

:param access_token: The access token to save.
:param expires_at: The UNIX timestamp of the expiration time.
Args:
access_token: The access token to save.
expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
self._atomic_write_secure_file(self.file_path, encrypted_data)

def get_token(self) -> str | None:
"""
Get the access token if it is valid and not expired.
"""Get the access token if it is valid and not expired.

:return: The access token if valid and not expired, otherwise None.
Returns:
The access token if valid and not expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)
encrypted_data = self._read_secure_file(self.file_path)
if encrypted_data is None:
return None

Expand All @@ -126,20 +79,18 @@ def get_token(self) -> str | None:
if expiration <= datetime.now():
return None

return cast(str | None, data["access_token"])
return cast(str | None, data.get("access_token"))

def clear_tokens(self) -> None:
"""
Clear the tokens.
"""
self.delete_secure_file(self.file_path)
"""Clear the stored tokens."""
self._delete_secure_file(self.file_path)

@staticmethod
def get_secure_storage_path() -> Path:
"""
Get the secure storage path based on the operating system.
def _get_secure_storage_path() -> Path:
"""Get the secure storage path based on the operating system.

:return: The secure storage path.
Returns:
The secure storage path.
"""
if sys.platform == "win32":
base_path = os.environ.get("LOCALAPPDATA")
Expand All @@ -155,44 +106,81 @@ def get_secure_storage_path() -> Path:

return storage_path

def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool:
"""Create a file only if it doesn't exist.

Args:
filename: The name of the file.
content: The content to write.

:param filename: The name of the file.
:param content: The content to save.
Returns:
True if file was created, False if it already exists.
"""
storage_path = self.get_secure_storage_path()
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename

with open(file_path, "wb") as f:
f.write(content)

os.chmod(file_path, 0o600)
try:
fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
try:
os.write(fd, content)
finally:
os.close(fd)
return True
except FileExistsError:
return False

def _atomic_write_secure_file(self, filename: str, content: bytes) -> None:
"""Write content to a secure file.

def read_secure_file(self, filename: str) -> bytes | None:
Args:
filename: The name of the file.
content: The content to write.
"""
Read the content of a secure file.
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename

fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.")
fd_closed = False
try:
os.write(fd, content)
os.close(fd)
fd_closed = True
os.chmod(temp_path, 0o600)
os.replace(temp_path, file_path)
except Exception:
if not fd_closed:
os.close(fd)
if os.path.exists(temp_path):
os.unlink(temp_path)
raise

def _read_secure_file(self, filename: str) -> bytes | None:
"""Read the content of a secure file.

Args:
filename: The name of the file.

:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
Returns:
The content of the file if it exists, otherwise None.
"""
storage_path = self.get_secure_storage_path()
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename

if not file_path.exists():
try:
with open(file_path, "rb") as f:
return f.read()
except FileNotFoundError:
return None

with open(file_path, "rb") as f:
return f.read()

def delete_secure_file(self, filename: str) -> None:
"""
Delete the secure file.
def _delete_secure_file(self, filename: str) -> None:
"""Delete a secure file.

:param filename: The name of the file.
Args:
filename: The name of the file.
"""
storage_path = self.get_secure_storage_path()
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
if file_path.exists():
file_path.unlink(missing_ok=True)
try:
file_path.unlink()
except FileNotFoundError:
pass
Loading
Loading