Skip to content

Commit b4dd4f1

Browse files
authored
Delete files associated with anonymous msgs (#258)
1 parent 41fa6f7 commit b4dd4f1

File tree

6 files changed

+169
-22
lines changed

6 files changed

+169
-22
lines changed

src/dao/message.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,55 @@ def remove(self, ids: list[str]) -> None:
668668
"""
669669
cursor.execute(q, (ids,))
670670

671+
def get_by_creator(self, creator: str) -> list[Message]:
    """Return the messages created by ``creator``, joined with their labels.

    Each message row is LEFT JOINed with the ``label`` table, so a message
    without labels is still returned (with NULL label columns).

    NOTE: the WHERE clause also accepts a NULL ``creator``; in that case the
    filter matches every row in ``message``.

    :param creator: id of the user whose messages should be fetched.
    :returns: one ``Message`` per fetched row, built with ``Message.from_row``.
    """
    query = """
        SELECT
            message.id,
            message.content,
            message.creator,
            message.role,
            message.opts,
            message.root,
            message.created,
            message.deleted,
            message.parent,
            message.template,
            message.logprobs,
            message.completion,
            message.final,
            message.original,
            message.private,
            message.model_type,
            message.finish_reason,
            message.harmful,
            message.model_id,
            message.model_host,
            message.expiration_time,
            message.file_urls,
            label.id,
            label.message,
            label.rating,
            label.creator,
            label.comment,
            label.created,
            label.deleted
        FROM
            message
        LEFT JOIN
            label
        ON
            label.message = message.id
        WHERE
            (message.creator = %(creator)s OR %(creator)s IS NULL)
    """

    with self.pool.connection() as conn, conn.cursor() as cur:
        fetched = cur.execute(query, {"creator": creator}).fetchall()

        return [Message.from_row(row) for row in fetched]
719+
671720
# TODO: allow listing non-final messages
672721
def get_list(
673722
self,
@@ -803,18 +852,26 @@ def get_list(
803852
)
804853

805854
def migrate_messages_to_new_user(self, previous_user_id: str, new_user_id: str):
    """Hand over all labels and messages of one user to another user.

    Labels are re-assigned first, then messages. Migrated messages also get
    their ``expiration_time`` cleared and ``private`` set to false.

    :param previous_user_id: creator id currently stored on the rows.
    :param new_user_id: creator id to write onto those rows.
    :returns: the ``rowcount`` of the message UPDATE (labels are not counted).
    """
    reassign_labels = """
        UPDATE label
        SET creator = %(new_user_id)s
        WHERE creator = %(previous_user_id)s
    """

    reassign_messages = """
        UPDATE message
        SET creator = %(new_user_id)s, expiration_time = NULL, private = false
        WHERE creator = %(previous_user_id)s
    """

    # Both statements share the same named parameters.
    bindings = {
        "new_user_id": new_user_id,
        "previous_user_id": previous_user_id,
    }

    with self.pool.connection() as conn, conn.cursor() as cur:
        cur.execute(query=reassign_labels, params=bindings)

        migrated = cur.execute(query=reassign_messages, params=bindings)
        return migrated.rowcount

src/message/GoogleCloudStorage.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import UTC, datetime
12
import re
23
from time import time_ns
34

@@ -6,6 +7,9 @@
67

78
from src.config import get_config
89

10+
# GOOGLE CLOUD STORAGE doesn't accept extreme datetime values like 3000 AD as custom time
11+
# For whoever sees this code in 2100 AD, please update the value!!!
12+
GCS_MAX_DATETIME_LIMIT = datetime(2100, 10, 31, tzinfo=UTC)
913

1014
class GoogleCloudStorage:
1115
client: Client
@@ -15,11 +19,18 @@ def __init__(self, bucket_name=get_config.cfg.google_cloud_services.storage_buck
1519
self.client = Client()
1620
self.bucket = self.client.bucket(bucket_name)
1721

18-
def upload_content(self, filename: str, content: bytes | str, content_type: str = "text/plain"):
22+
def upload_content(self, filename: str, content: bytes | str, content_type: str = "text/plain", is_anonymous: bool = False):
1923
start_ns = time_ns()
24+
2025
blob = self.bucket.blob(filename)
2126
blob.upload_from_string(data=content, content_type=content_type)
2227
blob.make_public()
28+
29+
# We're using the file's custom time to have GCS automatically delete files associated with anonymous msgs
30+
if is_anonymous:
31+
blob.custom_time = datetime.now(UTC)
32+
blob.patch()
33+
2334
end_ns = time_ns()
2435

2536
current_app.logger.info({
@@ -51,16 +62,76 @@ def delete_file(self, filename: str):
5162
})
5263

5364
def delete_multiple_files_by_url(self, file_urls: list[str]):
    """Delete the blobs behind ``file_urls`` from the bucket in one batch.

    Each URL is turned into a blob name by stripping the
    ``<api_endpoint>/<bucket name>/`` prefix. Names that do not resolve to a
    blob are skipped, so only blobs that were found are handed to
    ``delete_blobs``.

    Deletion is best-effort: a failure is logged (with traceback) and not
    re-raised. A ``batch_delete`` timing entry is logged in every case.

    :param file_urls: public URLs of the files to delete.
    """
    start_ns = time_ns()

    # The prefix is literal text, not a pattern: escape it so characters such
    # as "." in the endpoint or bucket name are not treated as regex wildcards.
    url_prefix = re.compile(re.escape(f"{self.client.api_endpoint}/{self.bucket.name}/"))
    file_names = [url_prefix.sub("", file_url) for file_url in file_urls]

    # One lookup per name; missing blobs come back as None and are dropped.
    found_blobs = []
    for name in file_names:
        blob = self.bucket.get_blob(blob_name=name)
        if blob is not None:
            found_blobs.append(blob)

    blob_names = [blob.name for blob in found_blobs]

    try:
        self.bucket.delete_blobs(found_blobs)
    except Exception:
        # logger.exception already records the active exception; passing it as
        # an extra positional arg would be treated as a %-format argument and
        # break the log record because the message has no placeholder.
        current_app.logger.exception(
            f"Failed to delete {','.join(blob_names)} from the bucket:{self.bucket.name} on GoogleCloudStorage"
        )

    end_ns = time_ns()

    current_app.logger.info({
        "service": "GoogleCloudStorage",
        "action": "batch_delete",
        "filename": ','.join(blob_names),
        "duration_ms": (end_ns - start_ns) / 1_000_000,
    })
93+
94+
def update_file_deletion_time(self, filename: str, new_time: datetime):
    """Set the custom time of the blob ``filename`` to ``new_time``.

    The custom time is what this service uses to have GCS delete files
    automatically, so changing it changes when the file is removed.

    A missing blob or a failing metadata update is logged and swallowed; the
    method returns ``None`` in every non-raising case.

    :param filename: blob name inside the bucket.
    :param new_time: new custom time; must not exceed ``GCS_MAX_DATETIME_LIMIT``.
    :raises ValueError: if ``new_time`` is later than ``GCS_MAX_DATETIME_LIMIT``.
    """
    if new_time > GCS_MAX_DATETIME_LIMIT:
        message = f"The new datetime for {filename} is over GoogleCloudStorage limit"
        current_app.logger.info(message)
        # ValueError is still an Exception, so existing `except Exception`
        # callers keep working, but the failure now carries a message.
        raise ValueError(message)

    start_ns = time_ns()
    try:
        blob = self.bucket.get_blob(blob_name=filename)
        if blob is None:
            # Report the missing blob once and stop, instead of raising into
            # the handler below and logging a second, misleading error.
            current_app.logger.error(
                f"Cannot find {filename} in the bucket:{self.bucket.name} on GoogleCloudStorage",
            )
            return None

        blob.custom_time = new_time
        blob.patch()

    except Exception:
        # logger.exception attaches the traceback; no extra positional arg,
        # because the message has no %-placeholder to receive it.
        current_app.logger.exception(
            f"Failed to update the metadata of {filename} in the bucket:{self.bucket.name} on GoogleCloudStorage"
        )

        return None

    end_ns = time_ns()

    current_app.logger.info({
        "service": "GoogleCloudStorage",
        "action": "update_file_deletion_time",
        "filename": filename,
        "duration_ms": (end_ns - start_ns) / 1_000_000,
    })
130+
131+
132+
def migrate_anonymous_file(self, filename: str):
    """Turn an anonymous user's file into a regular one.

    Logs the migration, then pushes the file's custom time out to
    ``GCS_MAX_DATETIME_LIMIT`` via ``update_file_deletion_time``.

    :param filename: blob name inside the bucket.
    """
    bucket_name = self.bucket.name
    current_app.logger.info(
        f"Migrating {filename} from anonymous to normal in the bucket:{bucket_name} on GoogleCloudStorage"
    )

    # GCS has no way to unset a custom time, so the furthest allowed time is
    # used as a stand-in.
    self.update_file_deletion_time(filename, GCS_MAX_DATETIME_LIMIT)

src/message/create_message_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def upload_request_files(
112112
message_id: str,
113113
storage_client: GoogleCloudStorage,
114114
root_message_id: str,
115+
is_anonymous: bool = False,
115116
) -> list[str] | None:
116117
if files is None or len(files) == 0:
117118
return None
@@ -125,12 +126,13 @@ def upload_request_files(
125126
filename = f"{root_message_id}/{message_id}-{i}{file_extension}"
126127

127128
if file.content_type is None:
128-
file_url = storage_client.upload_content(filename=filename, content=file.stream.read())
129+
file_url = storage_client.upload_content(filename=filename, content=file.stream.read(), is_anonymous=is_anonymous)
129130
else:
130131
file_url = storage_client.upload_content(
131132
filename=filename,
132133
content=file.stream.read(),
133134
content_type=file.content_type,
135+
is_anonymous=is_anonymous
134136
)
135137

136138
# since we read from the file we need to rewind it so the next consumer can read it
@@ -201,7 +203,7 @@ def stream_new_message(
201203
if is_content_safe is False or is_image_safe is False:
202204
raise exceptions.BadRequest(description="inappropriate_prompt")
203205

204-
# We currently want anonymous users' messages to expire after 1 day
206+
# We currently want anonymous users' messages to expire after 1 day
205207
message_expiration_time = datetime.now(UTC) + timedelta(days=1) if agent.is_anonymous_user else None
206208

207209
is_msg_harmful = None if is_content_safe is None or is_image_safe is None else False
@@ -285,6 +287,7 @@ def stream_new_message(
285287
message_id=msg.id,
286288
storage_client=storage_client,
287289
root_message_id=message_chain[0].id,
290+
is_anonymous=agent.is_anonymous_user
288291
)
289292

290293
chain: list[InferenceEngineMessage] = [

src/user/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@
1111
migrate_user_from_anonymous_user,
1212
upsert_user,
1313
)
14+
from src.message.GoogleCloudStorage import GoogleCloudStorage
1415

1516

1617
class UserBlueprint(Blueprint):
1718
dbc: db.Client
19+
storage_client: GoogleCloudStorage
1820

19-
def __init__(self, dbc: db.Client):
21+
def __init__(self, dbc: db.Client, storage_client: GoogleCloudStorage):
2022
super().__init__("user", __name__)
2123
self.dbc = dbc
24+
self.storage_client = storage_client
2225

2326
self.get("/whoami")(self.whoami)
2427
self.put("/user")(self.upsert_user)
@@ -66,6 +69,7 @@ def migrate_from_anonymous_user(self):
6669

6770
migration_result = migrate_user_from_anonymous_user(
6871
dbc=self.dbc,
72+
storage_client=self.storage_client,
6973
anonymous_user_id=migration_request.anonymous_user_id,
7074
new_user_id=migration_request.new_user_id,
7175
)

src/user/user_service.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from src.api_interface import APIInterface
99
from src.dao.user import User
1010
from src.hubspot_service import create_contact
11+
from src.message.GoogleCloudStorage import GoogleCloudStorage
1112

1213

1314
class UpsertUserRequest(APIInterface):
@@ -60,7 +61,7 @@ class MigrateFromAnonymousUserResponse(APIInterface):
6061
messages_updated_count: int = Field()
6162

6263

63-
def migrate_user_from_anonymous_user(dbc: db.Client, anonymous_user_id: str, new_user_id: str):
64+
def migrate_user_from_anonymous_user(dbc: db.Client, storage_client: GoogleCloudStorage, anonymous_user_id: str, new_user_id: str):
6465
# migrate tos
6566
previous_user = dbc.user.get_by_client(anonymous_user_id)
6667
new_user = dbc.user.get_by_client(new_user_id)
@@ -84,6 +85,17 @@ def migrate_user_from_anonymous_user(dbc: db.Client, anonymous_user_id: str, new
8485
elif previous_user is None and new_user is not None:
8586
updated_user = new_user
8687

88+
89+
msgs_to_be_migrated = dbc.message.get_by_creator(creator=anonymous_user_id)
90+
91+
for index, msg in enumerate(msgs_to_be_migrated):
92+
# 1. Migrate anonymous files on Google Cloud
93+
for url in msg.file_urls or []:
94+
filename = url.split('/')[-1]
95+
whole_name = f"{msg.root}/{filename}"
96+
storage_client.migrate_anonymous_file(whole_name)
97+
98+
# 2. Remove expiration time, set private to false, update messages and labels with new user id
8799
updated_messages_count = dbc.message.migrate_messages_to_new_user(
88100
previous_user_id=anonymous_user_id, new_user_id=new_user_id
89101
)

src/v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(self, dbc: db.Client, storage_client: GoogleCloudStorage):
5757
blueprint=create_v3_message_blueprint(dbc, storage_client=storage_client),
5858
url_prefix="/message",
5959
)
60-
self.register_blueprint(blueprint=UserBlueprint(dbc=dbc))
60+
self.register_blueprint(blueprint=UserBlueprint(dbc=dbc, storage_client=storage_client))
6161
self.register_blueprint(blueprint=attribution_blueprint, url_prefix="/attribution")
6262

6363
def prompts(self):

0 commit comments

Comments
 (0)