Skip to content

Commit 5c4f482

Browse files
Session Improvements Async (#168)
* Adds cache decorator & cache or middleware calls * Remove prints * Expunge test * Adds session decorator and middleware cached options * Change to kwargs instead of args * Pool info * IS_DEV change * improvements * cache, token reset * session tests * exception fix * thread session * debug removal * optional ctx token * os session lookup * clean up * clean up * PR comments --------- Co-authored-by: JWittmeyer <[email protected]>
1 parent 42fa11b commit 5c4f482

File tree

8 files changed

+392
-34
lines changed

8 files changed

+392
-34
lines changed

business_objects/general.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,32 @@
33
from sqlalchemy.orm.session import make_transient as make_transient_original
44
from ..session import session, engine
55
from ..session import request_id_ctx_var
6-
from ..session import check_session_and_rollback as check_and_roll
6+
from ..session import check_session_and_rollback
77
from ..enums import Tablenames, try_parse_enum_value
88
import traceback
99
import datetime
1010
from .. import daemon
1111
from threading import Lock
1212
from sqlalchemy.dialects import postgresql
1313
from sqlalchemy.sql import Select
14-
14+
import os
1515

1616
__THREAD_LOCK = Lock()
1717

18+
IS_DEV = os.getenv("IS_DEV", "false").lower() in {"true", "1", "yes", "y"}
19+
1820
session_lookup = {}
1921

2022

2123
def get_ctx_token() -> Any:
2224
global session_lookup
2325
session_uuid = str(uuid.uuid4())
2426
session_id = request_id_ctx_var.set(session_uuid)
25-
26-
call_stack = "".join(traceback.format_stack()[-5:])
27+
if IS_DEV:
28+
# traces are usually long running and only useful for debugging
29+
call_stack = "".join(traceback.format_stack()[-5:])
30+
else:
31+
call_stack = "Activate dev mode to see call stack"
2732
with __THREAD_LOCK:
2833
session_lookup[session_uuid] = {
2934
"session_id": session_uuid,
@@ -46,15 +51,18 @@ def get_session_lookup(exclude_last_x_seconds: int = 5) -> Dict[str, Dict[str, A
4651
]
4752

4853

49-
def reset_ctx_token(
50-
ctx_token: Any,
51-
remove_db: Optional[bool] = False,
52-
) -> None:
54+
def reset_ctx_token(ctx_token: Any = None, remove_db: Optional[bool] = False) -> None:
5355
if remove_db:
5456
session.remove()
55-
session_uuid = ctx_token.var.get()
5657

57-
request_id_ctx_var.reset(ctx_token)
58+
session_uuid = request_id_ctx_var.get()
59+
if session_uuid is None:
60+
print("Session not found in context variable", flush=True)
61+
if ctx_token:
62+
request_id_ctx_var.reset(ctx_token)
63+
else:
64+
request_id_ctx_var.set(None)
65+
5866
global session_lookup
5967
with __THREAD_LOCK:
6068
if session_uuid in session_lookup:
@@ -106,9 +114,13 @@ def commit() -> None:
106114

107115

108116
def remove_and_refresh_session(
109-
session_token: Any, request_new: bool = False
117+
session_token: Any = None, request_new: bool = False
110118
) -> Union[Any, None]:
111-
check_and_roll()
119+
try:
120+
check_session_and_rollback()
121+
except Exception:
122+
print("Error: check_session_and_rollback() failed", flush=True)
123+
print(traceback.format_exc(), flush=True)
112124
reset_ctx_token(session_token, True)
113125
if request_new:
114126
return get_ctx_token()

business_objects/organization.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,27 @@
44
from submodules.model import enums
55

66

7-
from ..session import session
7+
from ..session import session, request_id_ctx_var
88
from ..models import Organization, Project, User
99
from ..business_objects import project, user, general
1010
from ..util import prevent_sql_injection
11+
from ..db_cache import TTLCacheDecorator, CacheEnum
1112

1213

1314
def get(id: str) -> Organization:
1415
return session.query(Organization).get(id)
1516

1617

18+
@TTLCacheDecorator(CacheEnum.ORGANIZATION, 5, "id")
19+
def get_org_cached(id: str) -> Organization:
20+
o = get(id)
21+
if not o:
22+
return None
23+
general.expunge(o)
24+
general.make_transient(o)
25+
return o
26+
27+
1728
def get_by_name(name: str) -> Organization:
1829
return session.query(Organization).filter(Organization.name == name).first()
1930

business_objects/user.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List, Optional
66
from sqlalchemy import sql
77

8+
from ..db_cache import TTLCacheDecorator, CacheEnum
89

910
from ..util import prevent_sql_injection
1011

@@ -13,6 +14,17 @@ def get(user_id: str) -> User:
1314
return session.query(User).get(user_id)
1415

1516

17+
@TTLCacheDecorator(CacheEnum.USER, 5, "user_id")
18+
def get_user_cached(user_id: str) -> User:
19+
user = get(user_id)
20+
if not user:
21+
return None
22+
23+
general.expunge(user)
24+
general.make_transient(user)
25+
return user
26+
27+
1628
def get_by_id_list(user_ids: List[str]) -> List[User]:
1729
return session.query(User).filter(User.id.in_(user_ids)).all()
1830

cognition_objects/project.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from typing import List, Optional, Dict, Any, Iterable
22
from ..business_objects import general, team_resource, user
33
from ..cognition_objects import consumption_log, consumption_summary
4-
from ..session import session
4+
from ..session import session, request_id_ctx_var
55
from ..models import CognitionProject, TeamMember, TeamResource
66
from .. import enums
77
from datetime import datetime
88
from ..util import prevent_sql_injection
99
from sqlalchemy.orm.attributes import flag_modified
1010
from copy import deepcopy
11+
from ..db_cache import TTLCacheDecorator, CacheEnum
1112

1213

1314
def get(project_id: str) -> CognitionProject:
@@ -18,6 +19,16 @@ def get(project_id: str) -> CognitionProject:
1819
)
1920

2021

22+
@TTLCacheDecorator(CacheEnum.PROJECT, 5, "project_id")
23+
def get_cached(project_id: str) -> CognitionProject:
24+
p = get(project_id)
25+
if not p:
26+
return None
27+
general.expunge(p)
28+
general.make_transient(p)
29+
return p
30+
31+
2132
def get_org_id(project_id: str) -> str:
2233
if p := get(project_id):
2334
return str(p.organization_id)
@@ -42,6 +53,16 @@ def get_by_user(project_id: str, user_id: str) -> CognitionProject:
4253
)
4354

4455

56+
@TTLCacheDecorator(CacheEnum.PROJECT, 5, "project_id", "user_id")
57+
def get_by_user_cached(project_id: str, user_id: str) -> CognitionProject:
58+
p = get_by_user(project_id, user_id)
59+
if not p:
60+
return None
61+
general.expunge(p)
62+
general.make_transient(p)
63+
return p
64+
65+
4566
def get_all(org_id: str, order_by_name: bool = False) -> List[CognitionProject]:
4667
query = session.query(CognitionProject).filter(
4768
CognitionProject.organization_id == org_id

daemon.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import threading
22
from submodules.model.business_objects import general
3-
from contextvars import ContextVar
43
import traceback
54

6-
thread_session_token = ContextVar("token", default=None)
7-
85

96
def run_without_db_token(target, *args, **kwargs):
107
"""
@@ -27,35 +24,22 @@ def run_with_db_token(target, *args, **kwargs):
2724

2825
# this is a workaround to set the token in the actual thread context
2926
def wrapper():
30-
token = general.get_ctx_token()
31-
thread_session_token.set(token)
27+
general.get_ctx_token()
3228
try:
3329
target(*args, **kwargs)
3430
except Exception:
3531
print("=== Exception in thread ===", flush=True)
3632
print(traceback.format_exc(), flush=True)
3733
print("===========================", flush=True)
3834
finally:
39-
reset_session_token_in_thread(False)
35+
general.remove_and_refresh_session()
4036

4137
threading.Thread(
4238
target=wrapper,
4339
daemon=True,
4440
).start()
4541

4642

47-
def reset_session_token_in_thread(request_new: bool = True):
48-
token = thread_session_token.get()
49-
if not token:
50-
# shouldn't happen if used with the run_with_session_token function
51-
# so we print where it was called from
52-
print(traceback.format_stack())
53-
raise ValueError("No token set in thread context")
54-
new_token = general.remove_and_refresh_session(token, request_new)
55-
if new_token:
56-
thread_session_token.set(new_token)
57-
58-
5943
def prepare_thread(target, *args, **kwargs) -> threading.Thread:
6044
return threading.Thread(
6145
target=target,

0 commit comments

Comments
 (0)