Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 67 additions & 10 deletions src/dao/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,55 @@ def remove(self, ids: list[str]) -> None:
"""
cursor.execute(q, (ids,))

def get_by_creator(self, creator: str) -> list[Message]:
with self.pool.connection() as conn, conn.cursor() as cur:
q = """
SELECT
message.id,
message.content,
message.creator,
message.role,
message.opts,
message.root,
message.created,
message.deleted,
message.parent,
message.template,
message.logprobs,
message.completion,
message.final,
message.original,
message.private,
message.model_type,
message.finish_reason,
message.harmful,
message.model_id,
message.model_host,
message.expiration_time,
message.file_urls,
label.id,
label.message,
label.rating,
label.creator,
label.comment,
label.created,
label.deleted
FROM
message
LEFT JOIN
label
ON
label.message = message.id
WHERE
(message.creator = %(creator)s OR %(creator)s IS NULL)
"""

rows = cur.execute(q, { "creator": creator }).fetchall()

msg_list = list(map(Message.from_row, rows))

return msg_list

# TODO: allow listing non-final messages
def get_list(
self,
Expand Down Expand Up @@ -803,18 +852,26 @@ def get_list(
)

def migrate_messages_to_new_user(self, previous_user_id: str, new_user_id: str):

params = {
"new_user_id": new_user_id,
"previous_user_id": previous_user_id,
}

with self.pool.connection() as conn:
q = """
UPDATE message
ql = """
UPDATE label
SET creator = %(new_user_id)s
WHERE creator = %(previous_user_id)s
"""


qm = """
UPDATE message
SET creator = %(new_user_id)s, expiration_time = NULL, private = false
WHERE creator = %(previous_user_id)s
"""

with conn.cursor() as cur:
return cur.execute(
query=q,
params={
"new_user_id": new_user_id,
"previous_user_id": previous_user_id,
},
).rowcount
cur.execute( query=ql, params=params )

return cur.execute( query=qm, params=params ).rowcount
85 changes: 78 additions & 7 deletions src/message/GoogleCloudStorage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import UTC, datetime
import re
from time import time_ns

Expand All @@ -6,6 +7,9 @@

from src.config import get_config

# GOOGLE CLOUD STORAGE doesn't accept extreme datetime values like 3000 AD as custom time
# For whoever sees this code in 2100 AD, please update the value!!!
GCS_MAX_DATETIME_LIMIT = datetime(2100, 10, 31, tzinfo=UTC)

class GoogleCloudStorage:
client: Client
Expand All @@ -15,11 +19,18 @@ def __init__(self, bucket_name=get_config.cfg.google_cloud_services.storage_buck
self.client = Client()
self.bucket = self.client.bucket(bucket_name)

def upload_content(self, filename: str, content: bytes | str, content_type: str = "text/plain"):
def upload_content(self, filename: str, content: bytes | str, content_type: str = "text/plain", is_anonymous: bool = False):
start_ns = time_ns()

blob = self.bucket.blob(filename)
blob.upload_from_string(data=content, content_type=content_type)
blob.make_public()

# We're using the file's custom time to have GCS automatically delete files associated with anonymous msgs
if is_anonymous:
blob.custom_time = datetime.now(UTC)
blob.patch()

end_ns = time_ns()

current_app.logger.info({
Expand Down Expand Up @@ -51,16 +62,76 @@ def delete_file(self, filename: str):
})

def delete_multiple_files_by_url(self, file_urls: list[str]):
start_ns = time_ns()

file_names = [re.sub(f"{self.client.api_endpoint}/{self.bucket.name}/", "", file_url) for file_url in file_urls]
self.bucket.delete_blobs(file_names)

found_blobs = []
for name in file_names:
blob = self.bucket.get_blob(blob_name=name)
if (blob is not None):
found_blobs.append(blob)

blob_names = [blob.name for blob in found_blobs]

try:
self.bucket.delete_blobs(found_blobs)
except Exception as e:
current_app.logger.exception(
f"Failed to delete {','.join(blob_names)} from the bucket:{self.bucket.name} on GoogleCloudStorage",
repr(e),
)

end_ns = time_ns()

current_app.logger.info({
"service": "GoogleCloudStorage",
"action": "batch_delete",
"filename": ','.join(blob_names),
"duration_ms": (end_ns - start_ns) / 1_000_000,
})

def update_file_deletion_time(self, filename: str, new_time: datetime):
if (new_time > GCS_MAX_DATETIME_LIMIT):
current_app.logger.info(
f"The new datetime for {filename} is over GoogleCloudStorage limit"
)
raise Exception

start_ns = time_ns()
try:
blob = self.bucket.get_blob(blob_name=filename)
if blob is None:
current_app.logger.error(
f"Cannot find {filename} in the bucket:{self.bucket.name} on GoogleCloudStorage",
)

def get_file_link(self, filename: str):
blob = self.bucket.get_blob(blob_name=filename)
if blob is None:
raise Exception

blob.custom_time = new_time
blob.patch()

except Exception as e:
current_app.logger.error(
f"Cannot find {filename} in the bucket:{self.bucket.name} on GoogleCloudStorage",
f"Failed to update the metadata of {filename} in the bucket:{self.bucket.name} on GoogleCloudStorage",
repr(e),
)

return None

return blob.public_url
end_ns = time_ns()

current_app.logger.info({
"service": "GoogleCloudStorage",
"action": "update_file_deletion_time",
"filename": filename,
"duration_ms": (end_ns - start_ns) / 1_000_000,
})


def migrate_anonymous_file(self, filename: str):
current_app.logger.info(
f"Migrating {filename} from anonymous to normal in the bucket:{self.bucket.name} on GoogleCloudStorage",
)
# GCS doesn't allow unsetting custom time, instead we're setting it to the furthest time possible
self.update_file_deletion_time(filename, GCS_MAX_DATETIME_LIMIT)
7 changes: 5 additions & 2 deletions src/message/create_message_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def upload_request_files(
message_id: str,
storage_client: GoogleCloudStorage,
root_message_id: str,
is_anonymous: bool = False,
) -> list[str] | None:
if files is None or len(files) == 0:
return None
Expand All @@ -125,12 +126,13 @@ def upload_request_files(
filename = f"{root_message_id}/{message_id}-{i}{file_extension}"

if file.content_type is None:
file_url = storage_client.upload_content(filename=filename, content=file.stream.read())
file_url = storage_client.upload_content(filename=filename, content=file.stream.read(), is_anonymous=is_anonymous)
else:
file_url = storage_client.upload_content(
filename=filename,
content=file.stream.read(),
content_type=file.content_type,
is_anonymous=is_anonymous
)

# since we read from the file we need to rewind it so the next consumer can read it
Expand Down Expand Up @@ -201,7 +203,7 @@ def stream_new_message(
if is_content_safe is False or is_image_safe is False:
raise exceptions.BadRequest(description="inappropriate_prompt")

# We currently want anonymous users' messages to expire after 1 day
# We currently want anonymous users' messages to expire after 1 days
message_expiration_time = datetime.now(UTC) + timedelta(days=1) if agent.is_anonymous_user else None

is_msg_harmful = None if is_content_safe is None or is_image_safe is None else False
Expand Down Expand Up @@ -285,6 +287,7 @@ def stream_new_message(
message_id=msg.id,
storage_client=storage_client,
root_message_id=message_chain[0].id,
is_anonymous=agent.is_anonymous_user
)

chain: list[InferenceEngineMessage] = [
Expand Down
6 changes: 5 additions & 1 deletion src/user/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@
migrate_user_from_anonymous_user,
upsert_user,
)
from src.message.GoogleCloudStorage import GoogleCloudStorage


class UserBlueprint(Blueprint):
dbc: db.Client
storage_client: GoogleCloudStorage

def __init__(self, dbc: db.Client):
def __init__(self, dbc: db.Client, storage_client: GoogleCloudStorage):
super().__init__("user", __name__)
self.dbc = dbc
self.storage_client = storage_client

self.get("/whoami")(self.whoami)
self.put("/user")(self.upsert_user)
Expand Down Expand Up @@ -66,6 +69,7 @@ def migrate_from_anonymous_user(self):

migration_result = migrate_user_from_anonymous_user(
dbc=self.dbc,
storage_client=self.storage_client,
anonymous_user_id=migration_request.anonymous_user_id,
new_user_id=migration_request.new_user_id,
)
Expand Down
14 changes: 13 additions & 1 deletion src/user/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from src.api_interface import APIInterface
from src.dao.user import User
from src.hubspot_service import create_contact
from src.message.GoogleCloudStorage import GoogleCloudStorage


class UpsertUserRequest(APIInterface):
Expand Down Expand Up @@ -60,7 +61,7 @@ class MigrateFromAnonymousUserResponse(APIInterface):
messages_updated_count: int = Field()


def migrate_user_from_anonymous_user(dbc: db.Client, anonymous_user_id: str, new_user_id: str):
def migrate_user_from_anonymous_user(dbc: db.Client, storage_client: GoogleCloudStorage, anonymous_user_id: str, new_user_id: str):
# migrate tos
previous_user = dbc.user.get_by_client(anonymous_user_id)
new_user = dbc.user.get_by_client(new_user_id)
Expand All @@ -84,6 +85,17 @@ def migrate_user_from_anonymous_user(dbc: db.Client, anonymous_user_id: str, new
elif previous_user is None and new_user is not None:
updated_user = new_user


msgs_to_be_migrated = dbc.message.get_by_creator(creator=anonymous_user_id)

for index, msg in enumerate(msgs_to_be_migrated):
# 1. migrate anonyous files on Google Cloud
for url in msg.file_urls or []:
filename = url.split('/')[-1]
whole_name = f"{msg.root}/{filename}"
Comment on lines +94 to +95
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably use Python's URL parsing here. We can also get the full path from the URL!

Maybe something like this?

Suggested change
filename = url.split('/')[-1]
whole_name = f"{msg.root}/{filename}"
whole_name = urlparse(url).path

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The url includes the bucket name. I don't think we want it here..

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, right, because we're doing this within the context of a bucket. This is fine then!

storage_client.migrate_anonymous_file(whole_name)

# 2. Remove expiration time, set private to false, update messages and labels with new user id
updated_messages_count = dbc.message.migrate_messages_to_new_user(
previous_user_id=anonymous_user_id, new_user_id=new_user_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, dbc: db.Client, storage_client: GoogleCloudStorage):
blueprint=create_v3_message_blueprint(dbc, storage_client=storage_client),
url_prefix="/message",
)
self.register_blueprint(blueprint=UserBlueprint(dbc=dbc))
self.register_blueprint(blueprint=UserBlueprint(dbc=dbc, storage_client=storage_client))
self.register_blueprint(blueprint=attribution_blueprint, url_prefix="/attribution")

def prompts(self):
Expand Down