Skip to content

Commit b9dac8d

Browse files
authored
[PERF] argilla server: Reduce general transaction time (#5609)
# Description

<!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. -->

- Remove unnecessary extra SAVEPOINTS
- Move all search index operations out of the commit
- ~~the refresh parameter for es can be set to `wait_for` instead of `true` since index operations are out of transactions. This will help with the es performance.~~

**Type of change**

<!-- Please delete options that are not relevant. Remember to title the PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- Refactor (change restructuring the codebase without changing functionality)
- Improvement (change adding some improvement to an existing functionality)
- Documentation update

**How Has This Been Tested**

<!-- Please add some reference about how your feature has been tested. -->

**Checklist**

<!-- Please go over the list and make sure you've taken everything into account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature works
- I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent 66454e8 commit b9dac8d

File tree

10 files changed

+254
-269
lines changed

10 files changed

+254
-269
lines changed

argilla-server/pdm.lock

Lines changed: 59 additions & 62 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

argilla-server/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ dependencies = [
3030
"backoff>=1.11.1",
3131
# Database dependencies
3232
"alembic ~= 1.9.0",
33-
"SQLAlchemy == 2.0.31",
34-
"greenlet >= 2.0.0",
33+
"SQLAlchemy == 2.0.35",
34+
"greenlet ~= 3.1.0",
3535
# Async SQLite
3636
"aiosqlite == 0.20.0",
3737
# metrics

argilla-server/src/argilla_server/bulk/records_bulk.py

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -48,27 +48,26 @@ def __init__(self, db: AsyncSession, search_engine: SearchEngine):
4848
async def create_records_bulk(self, dataset: Dataset, bulk_create: RecordsBulkCreate) -> RecordsBulk:
4949
await RecordsBulkCreateValidator.validate(self._db, bulk_create, dataset)
5050

51-
async with self._db.begin_nested():
52-
records = [
53-
Record(
54-
fields=jsonable_encoder(record_create.fields),
55-
metadata_=record_create.metadata,
56-
external_id=record_create.external_id,
57-
dataset_id=dataset.id,
58-
)
59-
for record_create in bulk_create.items
60-
]
61-
62-
self._db.add_all(records)
63-
await self._db.flush(records)
64-
65-
await self._upsert_records_relationships(records, bulk_create.items)
66-
await _preload_records_relationships_before_index(self._db, records)
67-
await distribution.unsafe_update_records_status(self._db, records)
68-
await self._search_engine.index_records(dataset, records)
51+
records = [
52+
Record(
53+
fields=jsonable_encoder(record_create.fields),
54+
metadata_=record_create.metadata,
55+
external_id=record_create.external_id,
56+
dataset_id=dataset.id,
57+
)
58+
for record_create in bulk_create.items
59+
]
60+
61+
self._db.add_all(records)
62+
await self._db.flush(records)
63+
await self._upsert_records_relationships(records, bulk_create.items)
64+
await distribution.unsafe_update_records_status(self._db, records)
6965

7066
await self._db.commit()
7167

68+
await _preload_records_relationships_before_index(self._db, records)
69+
await self._search_engine.index_records(dataset, records)
70+
7271
return RecordsBulk(items=records)
7372

7473
async def _upsert_records_relationships(self, records: List[Record], records_create: List[RecordCreate]) -> None:
@@ -139,7 +138,8 @@ async def _upsert_records_vectors(
139138
autocommit=False,
140139
)
141140

142-
def _metadata_is_set(self, record_create: RecordCreate) -> bool:
141+
@classmethod
142+
def _metadata_is_set(cls, record_create: RecordCreate) -> bool:
143143
return "metadata" in record_create.__fields_set__
144144

145145

@@ -151,32 +151,32 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp
151151
await RecordsBulkUpsertValidator.validate(bulk_upsert, dataset, found_records)
152152

153153
records = []
154-
async with self._db.begin_nested():
155-
for record_upsert in bulk_upsert.items:
156-
record = found_records.get(record_upsert.id) or found_records.get(record_upsert.external_id)
157-
if not record:
158-
record = Record(
159-
fields=jsonable_encoder(record_upsert.fields),
160-
metadata_=record_upsert.metadata,
161-
external_id=record_upsert.external_id,
162-
dataset_id=dataset.id,
163-
)
164-
elif self._metadata_is_set(record_upsert):
165-
record.metadata_ = record_upsert.metadata
166-
record.updated_at = datetime.utcnow()
167-
168-
records.append(record)
169-
170-
self._db.add_all(records)
171-
await self._db.flush(records)
172-
173-
await self._upsert_records_relationships(records, bulk_upsert.items)
174-
await _preload_records_relationships_before_index(self._db, records)
175-
await distribution.unsafe_update_records_status(self._db, records)
176-
await self._search_engine.index_records(dataset, records)
154+
155+
for record_upsert in bulk_upsert.items:
156+
record = found_records.get(record_upsert.id) or found_records.get(record_upsert.external_id)
157+
if not record:
158+
record = Record(
159+
fields=jsonable_encoder(record_upsert.fields),
160+
metadata_=record_upsert.metadata,
161+
external_id=record_upsert.external_id,
162+
dataset_id=dataset.id,
163+
)
164+
elif self._metadata_is_set(record_upsert):
165+
record.metadata_ = record_upsert.metadata
166+
record.updated_at = datetime.utcnow()
167+
168+
records.append(record)
169+
170+
self._db.add_all(records)
171+
await self._db.flush(records)
172+
await self._upsert_records_relationships(records, bulk_upsert.items)
173+
await distribution.unsafe_update_records_status(self._db, records)
177174

178175
await self._db.commit()
179176

177+
await _preload_records_relationships_before_index(self._db, records)
178+
await self._search_engine.index_records(dataset, records)
179+
180180
return RecordsBulkWithUpdateInfo(
181181
items=records,
182182
updated_item_ids=[record.id for record in found_records.values()],

argilla-server/src/argilla_server/contexts/accounts.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -110,29 +110,28 @@ async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List
110110
if await get_user_by_username(db, user_attrs["username"]) is not None:
111111
raise NotUniqueError(f"User username `{user_attrs['username']}` is not unique")
112112

113-
async with db.begin_nested():
114-
user = await User.create(
115-
db,
116-
first_name=user_attrs["first_name"],
117-
last_name=user_attrs["last_name"],
118-
username=user_attrs["username"],
119-
role=user_attrs["role"],
120-
password_hash=hash_password(user_attrs["password"]),
121-
autocommit=False,
122-
)
123-
124-
if workspaces is not None:
125-
for workspace_name in workspaces:
126-
workspace = await Workspace.get_by(db, name=workspace_name)
127-
if not workspace:
128-
raise UnprocessableEntityError(f"Workspace '{workspace_name}' does not exist")
129-
130-
await WorkspaceUser.create(
131-
db,
132-
workspace_id=workspace.id,
133-
user_id=user.id,
134-
autocommit=False,
135-
)
113+
user = await User.create(
114+
db,
115+
first_name=user_attrs["first_name"],
116+
last_name=user_attrs["last_name"],
117+
username=user_attrs["username"],
118+
role=user_attrs["role"],
119+
password_hash=hash_password(user_attrs["password"]),
120+
autocommit=False,
121+
)
122+
123+
if workspaces is not None:
124+
for workspace_name in workspaces:
125+
workspace = await Workspace.get_by(db, name=workspace_name)
126+
if not workspace:
127+
raise UnprocessableEntityError(f"Workspace '{workspace_name}' does not exist")
128+
129+
await WorkspaceUser.create(
130+
db,
131+
workspace_id=workspace.id,
132+
user_id=user.id,
133+
autocommit=False,
134+
)
136135

137136
await db.commit()
138137

0 commit comments

Comments (0)