Skip to content

Commit fa1c7e6

Browse files
committed
Merge branch 'support/anybase' of https://github.com/apecloud/ApeRAG into support/anybase
2 parents 9d8349e + b1bafa0 commit fa1c7e6

File tree

7 files changed

+179
-81
lines changed

7 files changed

+179
-81
lines changed

aperag/db/repositories/collection.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,13 @@ async def _query(session):
201201
return result.scalars().first()
202202

203203
return await self._execute_query(_query)
204-
205-
async def query_collections_by_ids(self, user: str, collection_ids: List[str]):
204+
205+
async def query_collections_by_ids(self, collection_ids: List[str]):
206206
"""Query multiple collections by their IDs in a single database call"""
207207

208208
async def _query(session):
209209
stmt = select(Collection).where(
210210
Collection.id.in_(collection_ids),
211-
Collection.user == user,
212211
Collection.status != CollectionStatus.DELETED,
213212
)
214213
result = await session.execute(stmt)

aperag/db/repositories/marketplace.py

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,23 @@ async def _operation(session):
237237

238238
# Marketplace listing operations
239239
async def list_published_collections_with_subscription_status(
240-
self, user_id: str, page: int = 1, page_size: int = 12, user_group_id: Optional[str] = None
240+
self, user_id: str, page: int = 1, page_size: int = 12
241241
) -> Tuple[List[dict], int]:
242242
"""List all published collections accessible to user with subscription status"""
243243

244244
async def _query(session):
245+
# Get user's department information
246+
user_result = await session.execute(
247+
select(User).where(User.id == user_id)
248+
)
249+
user_obj = user_result.scalars().first()
250+
251+
if not user_obj:
252+
return []
253+
254+
user_group_id = user_obj.department_id
245255
# Get user's accessible department IDs
256+
246257
accessible_group_ids = ["*"] # Global collections are always accessible
247258

248259
if user_group_id:
@@ -338,6 +349,137 @@ async def _query(session):
338349

339350
return await self._execute_query(_query)
340351

352+
async def list_all_accessible_collections_for_user(
    self, user_id: str
) -> List[dict]:
    """Return the user's own collections plus published collections shared with them.

    Args:
        user_id: ID of the user whose accessible collections are listed.

    Returns:
        A list of dicts, one per collection, with the keys ``id``, ``title``,
        ``description``, ``config``, ``type``, ``status``, ``gmt_created``,
        ``gmt_updated``, ``owner_user_id``, ``owner_username``,
        ``marketplace_id``, ``marketplace_status``, ``published_at`` and
        ``group_id``. The list is empty when no user row matches ``user_id``.
        For owned collections the marketplace fields come from an outer join
        and may therefore be ``None``.
    """

    async def _query(session):
        # Resolve the user first; without a user row there is nothing to list.
        user_lookup = await session.execute(select(User).where(User.id == user_id))
        current_user = user_lookup.scalars().first()
        if not current_user:
            return []

        department_id = current_user.department_id

        # "*" marks globally published collections, which are always accessible.
        visible_group_ids = ["*"]
        if department_id:
            department_lookup = await session.execute(
                select(Department).where(Department.id == department_id)
            )
            department = department_lookup.scalars().first()

            if department and department.group_path:
                # Every non-empty segment of group_path is treated as an
                # accessible department ID.
                for segment in department.group_path.strip("/").split("/"):
                    if segment:
                        visible_group_ids.append(segment)

            # The user's direct department is always included.
            visible_group_ids.append(department_id)

        # Both halves of the UNION ALL project exactly the same columns.
        projected_columns = (
            Collection.id.label("id"),
            Collection.title,
            Collection.description,
            Collection.config,
            Collection.type,
            Collection.status,
            Collection.gmt_created,
            Collection.gmt_updated,
            Collection.user.label("owner_user_id"),
            User.username.label("owner_username"),
            CollectionMarketplace.id.label("marketplace_id"),
            CollectionMarketplace.status.label("marketplace_status"),
            CollectionMarketplace.gmt_created.label("published_at"),
            CollectionMarketplace.group_id,
        )

        # Collections owned by the user, with marketplace info when it exists.
        live_marketplace_row = and_(
            CollectionMarketplace.collection_id == Collection.id,
            CollectionMarketplace.gmt_deleted.is_(None)
        )
        own_stmt = (
            select(*projected_columns)
            .select_from(Collection)
            .join(User, Collection.user == User.id)
            .outerjoin(CollectionMarketplace, live_marketplace_row)
            .where(
                Collection.user == user_id,
                Collection.status != CollectionStatus.DELETED,
                Collection.gmt_deleted.is_(None),
            )
        )

        # Published collections owned by someone else and visible to the user.
        shared_stmt = (
            select(*projected_columns)
            .select_from(CollectionMarketplace)
            .join(Collection, CollectionMarketplace.collection_id == Collection.id)
            .join(User, Collection.user == User.id)
            .where(
                CollectionMarketplace.status == CollectionMarketplaceStatusEnum.PUBLISHED.value,
                CollectionMarketplace.gmt_deleted.is_(None),
                Collection.status != CollectionStatus.DELETED,
                Collection.gmt_deleted.is_(None),
                CollectionMarketplace.group_id.in_(visible_group_ids),
                Collection.user != user_id,  # Own collections come from own_stmt; avoid duplicates
            )
        )

        union_result = await session.execute(own_stmt.union_all(shared_stmt))

        # Flatten each row into a plain dict keyed by the projected column names.
        field_names = (
            'id',
            'title',
            'description',
            'config',
            'type',
            'status',
            'gmt_created',
            'gmt_updated',
            'owner_user_id',
            'owner_username',
            'marketplace_id',
            'marketplace_status',
            'published_at',
            'group_id',
        )
        return [
            {name: getattr(row, name) for name in field_names}
            for row in union_result.all()
        ]

    return await self._execute_query(_query)
482+
341483
# Subscription operations
342484
async def create_subscription(self, user_id: str, collection_marketplace_id: str) -> UserCollectionSubscription:
343485
"""Create a new subscription"""

aperag/service/agent_chat_service.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@
5353
from aperag.agent.response_types import AgentErrorResponse, AgentToolCallResultResponse
5454
from aperag.chat.history.message import StoredChatMessage, create_assistant_message
5555
from aperag.db.ops import AsyncDatabaseOps, async_db_ops
56+
from aperag.exceptions import ResourceNotFoundException
5657
from aperag.schema import view_models
58+
from aperag.schema.utils import parseCollectionConfig
5759
from aperag.service.prompt_template_service import build_agent_query_prompt, get_agent_system_prompt
5860
from aperag.trace import trace_async_function
5961

@@ -100,25 +102,6 @@ def __init__(self, session: AsyncSession = None):
100102
self.memory_manager = AgentMemoryManager()
101103
self.history_manager = AgentHistoryManager()
102104

103-
async def _convert_db_collections_to_pydantic(self, db_collections) -> List[view_models.Collection]:
    """Build a ``view_models.Collection`` for each database collection record.

    Args:
        db_collections: Iterable of collection records exposing ``id``,
            ``title``, ``description``, ``type``, ``config``, ``gmt_created``
            and ``gmt_updated``; ``status`` is optional.

    Returns:
        The converted collections, in the same order as the input.
    """
    from aperag.schema.utils import parseCollectionConfig

    return [
        view_models.Collection(
            id=record.id,
            title=record.title,
            description=record.description,
            type=record.type,
            # Records without a status attribute map to None.
            status=getattr(record, "status", None),
            config=parseCollectionConfig(record.config),
            created=record.gmt_created.isoformat(),
            updated=record.gmt_updated.isoformat(),
        )
        for record in db_collections
    ]
121-
122105
def _parse_websocket_message(
123106
self, raw_data: str
124107
) -> Tuple[Optional[view_models.AgentMessage], Optional[AgentErrorResponse]]:
@@ -191,10 +174,19 @@ async def handle_websocket_agent_chat(self, websocket: WebSocket, user: str, bot
191174

192175
# Get default collections once for performance
193176
if bot_config.agent.collections:
194-
collection_ids = [collection.id for collection in bot_config.agent.collections]
195-
db_collections = await self.db_ops.query_collections_by_ids(user, collection_ids)
196-
# Convert SQLAlchemy models to Pydantic models
197-
default_collections = await self._convert_db_collections_to_pydantic(db_collections)
177+
agent_collection_ids = [collection.id for collection in bot_config.agent.collections]
178+
agent_collections = await self.db_ops.query_collections_by_ids(agent_collection_ids)
179+
for agent_collection in agent_collections:
180+
default_collections.append(view_models.Collection(
181+
id=agent_collection.id,
182+
title=agent_collection.title,
183+
description=agent_collection.description,
184+
type=agent_collection.type,
185+
status=agent_collection.status,
186+
config=parseCollectionConfig(agent_collection.config),
187+
created=agent_collection.gmt_created.isoformat(),
188+
updated=agent_collection.gmt_updated.isoformat(),
189+
))
198190

199191
while True:
200192
# Receive message from WebSocket

aperag/service/bot_service.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from aperag.schema import view_models
2626
from aperag.schema.view_models import Bot, BotList
27+
from aperag.service import collection_service, marketplace_service
2728
from aperag.service.quota_service import quota_service
2829

2930

@@ -49,18 +50,20 @@ async def build_bot_response(self, bot: db_models.Bot) -> view_models.Bot:
4950
id=bot.id,
5051
title=bot.title,
5152
description=bot.description,
52-
type=bot.type,
53+
type=bot.type,
5354
config=bot_config,
5455
created=bot.gmt_created.isoformat(),
5556
updated=bot.gmt_updated.isoformat(),
5657
)
5758

5859
async def validate_collections(self, user: str, bot_config: view_models.BotConfig):
    """Check that every collection referenced by the bot's agent config is accessible.

    Accessibility is whatever ``self.db_ops.list_all_accessible_collections_for_user``
    reports for ``user``.

    Args:
        user: ID of the user the bot belongs to.
        bot_config: Bot configuration; may be falsy or lack agent collections,
            in which case there is nothing to validate.

    Raises:
        ResourceNotFoundException: For the first referenced collection ID that
            is not among the user's accessible collections.
    """
    if not (bot_config and bot_config.agent and bot_config.agent.collections):
        # No collections referenced: skip the accessibility query entirely.
        return

    accessible_collections = await self.db_ops.list_all_accessible_collections_for_user(user)
    # A set gives O(1) membership checks instead of scanning a list per ID.
    accessible_ids = {collection['id'] for collection in accessible_collections}

    for agent_collection in bot_config.agent.collections:
        if agent_collection.id not in accessible_ids:
            raise ResourceNotFoundException("Collection", agent_collection.id)
6467

6568
async def create_bot(
6669
self, user: str, bot_in: view_models.BotCreate, skip_quota_check: bool = False

aperag/service/collection_service.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -616,49 +616,6 @@ async def delete_search(self, user: str, collection_id: str, search_id: str) ->
616616

617617
return await self.db_ops.delete_search(user, collection_id, search_id)
618618

619-
async def validate_collections_batch(
    self, user: str, collections: list[view_models.Collection]
) -> tuple[bool, str]:
    """Validate several collections with a single database lookup.

    Args:
        user: User identifier passed through to the database query.
        collections: Collection objects to validate; each must carry an ``id``.

    Returns:
        ``(is_valid, error_message)``. ``error_message`` is empty when valid.
        Failures of the database lookup are reported through the tuple rather
        than raised.
    """
    if not collections:
        return True, ""

    requested_ids = []
    for item in collections:
        if not item.id:
            return False, "Collection object missing 'id' field"
        requested_ids.append(item.id)

    # dict keys keep insertion order, so this de-duplicates without reordering.
    distinct_ids = list(dict.fromkeys(requested_ids))

    try:
        # One query fetches every requested collection.
        rows = await self.db_ops.query_collections_by_ids(user, distinct_ids)
        known_ids = {str(row.id) for row in rows}

        # Report the first requested ID the query did not return.
        first_missing = next((cid for cid in distinct_ids if cid not in known_ids), None)
        if first_missing is not None:
            return False, f"Collection {first_missing} not found"

        return True, ""

    except Exception as e:
        return False, f"Failed to validate collections: {str(e)}"
661-
662619
async def test_mineru_token(self, token: str) -> dict:
663620
"""Test the MinerU API token."""
664621
async with httpx.AsyncClient() as client:

envs/env.template

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,16 @@ POSTGRES_PASSWORD=postgres
77

88
# Database Connection Pool Settings
99
# Adjust these values based on your server resources and expected load
10-
DB_POOL_SIZE=20 # Base connection pool size
11-
DB_MAX_OVERFLOW=40 # Maximum overflow connections (total = pool_size + max_overflow)
12-
DB_POOL_TIMEOUT=60 # Connection timeout in seconds
13-
DB_POOL_RECYCLE=3600 # Recycle connections after 1 hour (in seconds)
14-
DB_POOL_PRE_PING=True # Validate connections before use
10+
# Base connection pool size
11+
DB_POOL_SIZE=20
12+
# Maximum overflow connections (total = pool_size + max_overflow)
13+
DB_MAX_OVERFLOW=40
14+
# Connection timeout in seconds
15+
DB_POOL_TIMEOUT=60
16+
# Recycle connections after 1 hour (in seconds)
17+
DB_POOL_RECYCLE=3600
18+
# Validate connections before use
19+
DB_POOL_PRE_PING=True
1520

1621
# Redis
1722
REDIS_HOST=127.0.0.1

web/src/app/workspace/collections/[collectionId]/documents/upload/document-upload.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ export const DocumentUpload = () => {
378378
<>
379379
<FileUpload
380380
maxFiles={1000}
381-
maxSize={10 * 1024 * 1024}
381+
maxSize={100 * 1024 * 1024}
382382
className="w-full gap-4"
383383
accept=".pdf,.doc,.docx,.txt,.md,.ppt,.pptx,.xls,.xlsx"
384384
value={documents.map((f) => f.file)}

0 commit comments

Comments
 (0)