Skip to content

Commit beef712

Browse files
fix: ensure token store file ops do not deadlock
* fix: ensure token store file ops do not deadlock * chore: update test method reference
1 parent 6125b86 commit beef712

File tree

3 files changed

+332
-216
lines changed

3 files changed

+332
-216
lines changed

lib/crewai/src/crewai/cli/shared/token_manager.py

Lines changed: 100 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -3,119 +3,72 @@
33
import os
44
from pathlib import Path
55
import sys
6-
from typing import BinaryIO, cast
6+
import tempfile
7+
from typing import Final, Literal, cast
78

89
from cryptography.fernet import Fernet
910

1011

11-
if sys.platform == "win32":
12-
import msvcrt
13-
else:
14-
import fcntl
12+
_FERNET_KEY_LENGTH: Final[Literal[44]] = 44
1513

1614

1715
class TokenManager:
16+
"""Manages encrypted token storage."""
17+
1818
def __init__(self, file_path: str = "tokens.enc") -> None:
19-
"""
20-
Initialize the TokenManager class.
19+
"""Initialize the TokenManager.
2120
22-
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
21+
Args:
22+
file_path: The file path to store encrypted tokens.
2323
"""
2424
self.file_path = file_path
2525
self.key = self._get_or_create_key()
2626
self.fernet = Fernet(self.key)
2727

28-
@staticmethod
29-
def _acquire_lock(file_handle: BinaryIO) -> None:
30-
"""
31-
Acquire an exclusive lock on a file handle.
32-
33-
Args:
34-
file_handle: Open file handle to lock.
35-
"""
36-
if sys.platform == "win32":
37-
msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
38-
else:
39-
fcntl.flock(file_handle.fileno(), fcntl.LOCK_EX)
40-
41-
@staticmethod
42-
def _release_lock(file_handle: BinaryIO) -> None:
43-
"""
44-
Release the lock on a file handle.
45-
46-
Args:
47-
file_handle: Open file handle to unlock.
48-
"""
49-
if sys.platform == "win32":
50-
msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
51-
else:
52-
fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
53-
5428
def _get_or_create_key(self) -> bytes:
55-
"""
56-
Get or create the encryption key with file locking to prevent race conditions.
29+
"""Get or create the encryption key.
5730
5831
Returns:
59-
The encryption key.
32+
The encryption key as bytes.
6033
"""
61-
key_filename = "secret.key"
62-
storage_path = self.get_secure_storage_path()
34+
key_filename: str = "secret.key"
6335

64-
key = self.read_secure_file(key_filename)
65-
if key is not None and len(key) == 44:
36+
key = self._read_secure_file(key_filename)
37+
if key is not None and len(key) == _FERNET_KEY_LENGTH:
6638
return key
6739

68-
lock_file_path = storage_path / f"{key_filename}.lock"
69-
70-
try:
71-
lock_file_path.touch()
72-
73-
with open(lock_file_path, "r+b") as lock_file:
74-
self._acquire_lock(lock_file)
75-
try:
76-
key = self.read_secure_file(key_filename)
77-
if key is not None and len(key) == 44:
78-
return key
79-
80-
new_key = Fernet.generate_key()
81-
self.save_secure_file(key_filename, new_key)
82-
return new_key
83-
finally:
84-
try:
85-
self._release_lock(lock_file)
86-
except OSError:
87-
pass
88-
except OSError:
89-
key = self.read_secure_file(key_filename)
90-
if key is not None and len(key) == 44:
91-
return key
92-
93-
new_key = Fernet.generate_key()
94-
self.save_secure_file(key_filename, new_key)
40+
new_key = Fernet.generate_key()
41+
if self._atomic_create_secure_file(key_filename, new_key):
9542
return new_key
9643

44+
key = self._read_secure_file(key_filename)
45+
if key is not None and len(key) == _FERNET_KEY_LENGTH:
46+
return key
47+
48+
raise RuntimeError("Failed to create or read encryption key")
49+
9750
def save_tokens(self, access_token: str, expires_at: int) -> None:
98-
"""
99-
Save the access token and its expiration time.
51+
"""Save the access token and its expiration time.
10052
101-
:param access_token: The access token to save.
102-
:param expires_at: The UNIX timestamp of the expiration time.
53+
Args:
54+
access_token: The access token to save.
55+
expires_at: The UNIX timestamp of the expiration time.
10356
"""
10457
expiration_time = datetime.fromtimestamp(expires_at)
10558
data = {
10659
"access_token": access_token,
10760
"expiration": expiration_time.isoformat(),
10861
}
10962
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
110-
self.save_secure_file(self.file_path, encrypted_data)
63+
self._atomic_write_secure_file(self.file_path, encrypted_data)
11164

11265
def get_token(self) -> str | None:
113-
"""
114-
Get the access token if it is valid and not expired.
66+
"""Get the access token if it is valid and not expired.
11567
116-
:return: The access token if valid and not expired, otherwise None.
68+
Returns:
69+
The access token if valid and not expired, otherwise None.
11770
"""
118-
encrypted_data = self.read_secure_file(self.file_path)
71+
encrypted_data = self._read_secure_file(self.file_path)
11972
if encrypted_data is None:
12073
return None
12174

@@ -126,20 +79,18 @@ def get_token(self) -> str | None:
12679
if expiration <= datetime.now():
12780
return None
12881

129-
return cast(str | None, data["access_token"])
82+
return cast(str | None, data.get("access_token"))
13083

13184
def clear_tokens(self) -> None:
132-
"""
133-
Clear the tokens.
134-
"""
135-
self.delete_secure_file(self.file_path)
85+
"""Clear the stored tokens."""
86+
self._delete_secure_file(self.file_path)
13687

13788
@staticmethod
138-
def get_secure_storage_path() -> Path:
139-
"""
140-
Get the secure storage path based on the operating system.
89+
def _get_secure_storage_path() -> Path:
90+
"""Get the secure storage path based on the operating system.
14191
142-
:return: The secure storage path.
92+
Returns:
93+
The secure storage path.
14394
"""
14495
if sys.platform == "win32":
14596
base_path = os.environ.get("LOCALAPPDATA")
@@ -155,44 +106,81 @@ def get_secure_storage_path() -> Path:
155106

156107
return storage_path
157108

158-
def save_secure_file(self, filename: str, content: bytes) -> None:
159-
"""
160-
Save the content to a secure file.
109+
def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool:
110+
"""Create a file only if it doesn't exist.
111+
112+
Args:
113+
filename: The name of the file.
114+
content: The content to write.
161115
162-
:param filename: The name of the file.
163-
:param content: The content to save.
116+
Returns:
117+
True if file was created, False if it already exists.
164118
"""
165-
storage_path = self.get_secure_storage_path()
119+
storage_path = self._get_secure_storage_path()
166120
file_path = storage_path / filename
167121

168-
with open(file_path, "wb") as f:
169-
f.write(content)
170-
171-
os.chmod(file_path, 0o600)
122+
try:
123+
fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
124+
try:
125+
os.write(fd, content)
126+
finally:
127+
os.close(fd)
128+
return True
129+
except FileExistsError:
130+
return False
131+
132+
def _atomic_write_secure_file(self, filename: str, content: bytes) -> None:
133+
"""Write content to a secure file.
172134
173-
def read_secure_file(self, filename: str) -> bytes | None:
135+
Args:
136+
filename: The name of the file.
137+
content: The content to write.
174138
"""
175-
Read the content of a secure file.
139+
storage_path = self._get_secure_storage_path()
140+
file_path = storage_path / filename
141+
142+
fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.")
143+
fd_closed = False
144+
try:
145+
os.write(fd, content)
146+
os.close(fd)
147+
fd_closed = True
148+
os.chmod(temp_path, 0o600)
149+
os.replace(temp_path, file_path)
150+
except Exception:
151+
if not fd_closed:
152+
os.close(fd)
153+
if os.path.exists(temp_path):
154+
os.unlink(temp_path)
155+
raise
156+
157+
def _read_secure_file(self, filename: str) -> bytes | None:
158+
"""Read the content of a secure file.
159+
160+
Args:
161+
filename: The name of the file.
176162
177-
:param filename: The name of the file.
178-
:return: The content of the file if it exists, otherwise None.
163+
Returns:
164+
The content of the file if it exists, otherwise None.
179165
"""
180-
storage_path = self.get_secure_storage_path()
166+
storage_path = self._get_secure_storage_path()
181167
file_path = storage_path / filename
182168

183-
if not file_path.exists():
169+
try:
170+
with open(file_path, "rb") as f:
171+
return f.read()
172+
except FileNotFoundError:
184173
return None
185174

186-
with open(file_path, "rb") as f:
187-
return f.read()
188-
189-
def delete_secure_file(self, filename: str) -> None:
190-
"""
191-
Delete the secure file.
175+
def _delete_secure_file(self, filename: str) -> None:
176+
"""Delete a secure file.
192177
193-
:param filename: The name of the file.
178+
Args:
179+
filename: The name of the file.
194180
"""
195-
storage_path = self.get_secure_storage_path()
181+
storage_path = self._get_secure_storage_path()
196182
file_path = storage_path / filename
197-
if file_path.exists():
198-
file_path.unlink(missing_ok=True)
183+
try:
184+
file_path.unlink()
185+
except FileNotFoundError:
186+
pass

0 commit comments

Comments
 (0)