Skip to content

Commit 987f6ac

Browse files
KIRA009abrichr
andauthored
fix(db): Database access refactor (#676)
* feat: Remove global sessions, and introduce read only sessions for cases where no writing is required * refactor: Rename db to session * feat: Raise exceptions if commit/write/delete is attempted on a read-only session * feat: Add tests for the read only session * chore: lint using flake8 * rename test_database -> db_engine --------- Co-authored-by: Richard Abrich <[email protected]>
1 parent b438c9c commit 987f6ac

File tree

24 files changed

+483
-327
lines changed

24 files changed

+483
-327
lines changed

experiments/imagesimilarity.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Callable
44
import time
55

6-
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
6+
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
77
from PIL import Image, ImageOps
88
from skimage.metrics import structural_similarity as ssim
99
from sklearn.manifold import MDS
@@ -12,8 +12,7 @@
1212
import matplotlib.pyplot as plt
1313
import numpy as np
1414

15-
from openadapt.db import crud
16-
15+
from openadapt.session import crud
1716

1817
SHOW_SSIM = False
1918

@@ -290,7 +289,8 @@ def display_distance_matrix_with_images(
290289

291290
def main() -> None:
292291
"""Main function to process images and display similarity metrics."""
293-
recording = crud.get_latest_recording()
292+
session = crud.get_new_session(read_only=True)
293+
recording = crud.get_latest_recording(session)
294294
action_events = recording.processed_action_events
295295
images = [action_event.screenshot.cropped_image for action_event in action_events]
296296

openadapt/app/cards.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from openadapt.app.objects.local_file_picker import LocalFilePicker
1414
from openadapt.app.util import get_scrub, set_dark, set_scrub, sync_switch
15-
from openadapt.db.crud import new_session
1615
from openadapt.record import record
1716

1817

@@ -146,7 +145,6 @@ def quick_record(
146145
) -> None:
147146
"""Run a recording session."""
148147
global record_proc
149-
new_session()
150148
task_description = task_description or datetime.now().strftime("%d/%m/%Y %H:%M:%S")
151149
record_proc.start(
152150
record,
@@ -204,7 +202,6 @@ def begin() -> None:
204202
ui.notify(
205203
f"Recording {name}... Press CTRL + C in terminal window to cancel",
206204
)
207-
new_session()
208205
global record_proc
209206
record_proc.start(
210207
record,

openadapt/app/dashboard/api/recordings.py

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -32,29 +32,27 @@ def attach_routes(self) -> APIRouter:
3232
@staticmethod
3333
def get_recordings() -> dict[str, list[Recording]]:
3434
"""Get all recordings."""
35-
session = crud.get_new_session()
35+
session = crud.get_new_session(read_only=True)
3636
recordings = crud.get_all_recordings(session)
3737
return {"recordings": recordings}
3838

3939
@staticmethod
4040
def get_scrubbed_recordings() -> dict[str, list[Recording]]:
4141
"""Get all scrubbed recordings."""
42-
session = crud.get_new_session()
42+
session = crud.get_new_session(read_only=True)
4343
recordings = crud.get_all_scrubbed_recordings(session)
4444
return {"recordings": recordings}
4545

4646
@staticmethod
47-
async def start_recording() -> dict[str, str]:
47+
def start_recording() -> dict[str, str | int]:
4848
"""Start a recording session."""
49-
await crud.acquire_db_lock()
5049
cards.quick_record()
51-
return {"message": "Recording started"}
50+
return {"message": "Recording started", "status": 200}
5251

5352
@staticmethod
5453
def stop_recording() -> dict[str, str]:
5554
"""Stop a recording session."""
5655
cards.stop_record()
57-
crud.release_db_lock()
5856
return {"message": "Recording stopped"}
5957

6058
@staticmethod
@@ -69,48 +67,45 @@ def recording_detail_route(self) -> None:
6967
async def get_recording_detail(websocket: WebSocket, recording_id: int) -> None:
7068
"""Get a specific recording and its action events."""
7169
await websocket.accept()
72-
session = crud.get_new_session()
73-
with session:
74-
recording = crud.get_recording_by_id(recording_id, session)
75-
76-
await websocket.send_json(
77-
{"type": "recording", "value": recording.asdict()}
78-
)
79-
80-
action_events = get_events(recording, session=session)
81-
82-
await websocket.send_json(
83-
{"type": "num_events", "value": len(action_events)}
84-
)
85-
86-
def convert_to_str(event_dict: dict) -> dict:
87-
"""Convert the keys to strings."""
88-
if "key" in event_dict:
89-
event_dict["key"] = str(event_dict["key"])
90-
if "canonical_key" in event_dict:
91-
event_dict["canonical_key"] = str(event_dict["canonical_key"])
92-
if "reducer_names" in event_dict:
93-
event_dict["reducer_names"] = list(event_dict["reducer_names"])
94-
if "children" in event_dict:
95-
for child_event in event_dict["children"]:
96-
convert_to_str(child_event)
97-
98-
for action_event in action_events:
99-
event_dict = row2dict(action_event)
100-
try:
101-
image = display_event(action_event)
102-
width, height = image.size
103-
image = image2utf8(image)
104-
except Exception:
105-
logger.info("Failed to display event")
106-
image = None
107-
width, height = 0, 0
108-
event_dict["screenshot"] = image
109-
event_dict["dimensions"] = {"width": width, "height": height}
110-
111-
convert_to_str(event_dict)
112-
await websocket.send_json(
113-
{"type": "action_event", "value": event_dict}
114-
)
115-
116-
await websocket.close()
70+
session = crud.get_new_session(read_only=True)
71+
recording = crud.get_recording_by_id(session, recording_id)
72+
73+
await websocket.send_json(
74+
{"type": "recording", "value": recording.asdict()}
75+
)
76+
77+
action_events = get_events(session, recording)
78+
79+
await websocket.send_json(
80+
{"type": "num_events", "value": len(action_events)}
81+
)
82+
83+
def convert_to_str(event_dict: dict) -> dict:
84+
"""Convert the keys to strings."""
85+
if "key" in event_dict:
86+
event_dict["key"] = str(event_dict["key"])
87+
if "canonical_key" in event_dict:
88+
event_dict["canonical_key"] = str(event_dict["canonical_key"])
89+
if "reducer_names" in event_dict:
90+
event_dict["reducer_names"] = list(event_dict["reducer_names"])
91+
if "children" in event_dict:
92+
for child_event in event_dict["children"]:
93+
convert_to_str(child_event)
94+
95+
for action_event in action_events:
96+
event_dict = row2dict(action_event)
97+
try:
98+
image = display_event(action_event)
99+
width, height = image.size
100+
image = image2utf8(image)
101+
except Exception:
102+
logger.info("Failed to display event")
103+
image = None
104+
width, height = 0, 0
105+
event_dict["screenshot"] = image
106+
event_dict["dimensions"] = {"width": width, "height": height}
107+
108+
convert_to_str(event_dict)
109+
await websocket.send_json({"type": "action_event", "value": event_dict})
110+
111+
await websocket.close()

openadapt/app/dashboard/api/scrubbing.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from fastapi.responses import StreamingResponse
99

1010
from openadapt.config import config
11-
from openadapt.db import crud
1211
from openadapt.privacy.providers import ScrubProvider
1312
from openadapt.scrub import get_scrubbing_process, scrub
1413

@@ -62,7 +61,6 @@ async def scrub_recording(recording_id: int, provider_id: str) -> dict[str, str]
6261
}
6362
if provider_id not in ScrubProvider.get_available_providers():
6463
return {"message": "Provider not supported", "status": "failed"}
65-
await crud.acquire_db_lock()
6664
scrub(recording_id, provider_id, release_lock=True)
6765
scrubbing_proc = get_scrubbing_process()
6866
while not scrubbing_proc.is_running():

openadapt/app/dashboard/components/Shell/Shell.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
'use client'
22

3-
import { AppShell, Burger, Image, Text } from '@mantine/core'
3+
import { AppShell, Box, Burger, Image, Text } from '@mantine/core'
44
import React from 'react'
55
import { Navbar } from '../Navbar'
66
import { useDisclosure } from '@mantine/hooks'
@@ -30,12 +30,12 @@ export const Shell = ({ children }: Props) => {
3030
hiddenFrom="sm"
3131
size="sm"
3232
/>
33-
<Text className="h-full flex items-center px-5 gap-x-2">
33+
<Box className="h-full flex items-center px-5 gap-x-2">
3434
<Image src={logo.src} alt="OpenAdapt" w={40} />
3535
<Text>
3636
OpenAdapt.AI
3737
</Text>
38-
</Text>
38+
</Box>
3939
</AppShell.Header>
4040

4141
<AppShell.Navbar>

openadapt/app/tray.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ def __init__(self) -> None:
7676

7777
self.app.setQuitOnLastWindowClosed(False)
7878

79+
# since the lock is a file, delete it when starting the app so that
80+
# new instances can start even if the previous one crashed
81+
crud.release_db_lock(raise_exception=False)
82+
7983
# currently required for pyqttoast
8084
# TODO: remove once https://github.com/niklashenning/pyqt-toast/issues/9
8185
# is addressed
@@ -379,7 +383,12 @@ def _delete(self, recording: Recording) -> None:
379383
"""
380384
dialog = ConfirmDeleteDialog(recording.task_description)
381385
if dialog.exec_():
382-
crud.delete_recording(recording.timestamp)
386+
if not crud.acquire_db_lock():
387+
self.show_toast("Failed to delete recording. Try again later.")
388+
return
389+
with crud.get_new_session(read_and_write=True) as session:
390+
crud.delete_recording(session, recording)
391+
crud.release_db_lock()
383392
self.show_toast("Recording deleted.")
384393
self.populate_menus()
385394

@@ -413,7 +422,8 @@ def populate_menu(self, menu: QMenu, action: Callable, action_type: str) -> None
413422
action (Callable): The function to call when the menu item is clicked.
414423
action_type (str): The type of action to perform ["visualize", "replay"]
415424
"""
416-
recordings = crud.get_all_recordings()
425+
session = crud.get_new_session(read_only=True)
426+
recordings = crud.get_all_recordings(session)
417427

418428
self.recording_actions[action_type] = []
419429

openadapt/app/visualize.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import click
1717

1818
from openadapt.config import config
19-
from openadapt.db.crud import get_latest_recording, get_recording
19+
from openadapt.db import crud
2020
from openadapt.events import get_events
2121
from openadapt.utils import (
2222
EMPTY,
@@ -141,18 +141,19 @@ def main(timestamp: str) -> None:
141141
configure_logging(logger, LOG_LEVEL)
142142

143143
ui_dark = ui.dark_mode(config.VISUALIZE_DARK_MODE)
144+
session = crud.get_new_session(read_only=True)
144145

145146
if timestamp is None:
146-
recording = get_latest_recording()
147+
recording = crud.get_latest_recording(session)
147148
else:
148-
recording = get_recording(timestamp)
149+
recording = crud.get_recording(session, timestamp)
149150

150151
if SCRUB:
151152
scrub.scrub_text(recording.task_description)
152153
logger.debug(f"{recording=}")
153154

154155
meta = {}
155-
action_events = get_events(recording, process=PROCESS_EVENTS, meta=meta)
156+
action_events = get_events(session, recording, process=PROCESS_EVENTS, meta=meta)
156157
event_dicts = rows2dicts(action_events)
157158

158159
if SCRUB:

0 commit comments

Comments
 (0)