@@ -1,14 +1,18 @@
 import math
+import os
+import shutil
+import asyncio
+import uuid
+from datetime import datetime
+from pathlib import Path
+from typing import Optional, List, Dict, Any
+
+from sqlalchemy import func
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
-from sqlalchemy import func
-from typing import Optional, List, Dict, Any
-from datetime import datetime
 
-from app.core.config import settings
 from app.core.logging import get_logger
 from app.db.models import Dataset, DatasetFiles
-
 from ..schema import DatasetResponse, PagedDatasetFileResponse, DatasetFileResponse
 
 logger = get_logger(__name__)
@@ -263,3 +267,110 @@ async def update_file_tags_partial( |
             logger.error(f"Failed to update tags for file {file_id}: {e}")
             await self.db.rollback()
             return False, str(e), None
+
+    @staticmethod
+    async def _get_or_create_dataset_directory(dataset: Dataset) -> str:
+        """Get or create the dataset's storage directory."""
+        dataset_dir = dataset.path
+        await asyncio.to_thread(os.makedirs, dataset_dir, exist_ok=True)
+        return dataset_dir
+
+    async def add_files_to_dataset(self, dataset_id: str, source_paths: List[str]) -> None:
+        """
+        Copy files into the dataset directory and create the matching database records.
+
+        Args:
+            dataset_id: ID of the dataset
+            source_paths: List of source file paths to copy
+
+        Returns:
+            None. Created and updated records are committed directly to the database.
+        """
+        logger.info(f"Starting to add files to dataset {dataset_id}")
+
+        try:
+            # Get the dataset and its existing files
+            dataset = await self.db.get(Dataset, dataset_id)
+            if not dataset:
+                logger.error(f"Dataset not found: {dataset_id}")
+                return
+
+            # Map existing files by target path to detect duplicates
+            result = await self.db.execute(
+                select(DatasetFiles).where(DatasetFiles.dataset_id == dataset_id)
+            )
+            existing_files_map: Dict[str, DatasetFiles] = {}
+            for dataset_file in result.scalars().all():
+                existing_files_map[dataset_file.file_path] = dataset_file
+
+            # Get or create the dataset directory
+            dataset_dir = await self._get_or_create_dataset_directory(dataset)
+
+            # Process each source file
+            for source_path in source_paths:
+                try:
+                    file_record = await self.create_new_dataset_file(dataset_dir, dataset_id, source_path)
+                    if not file_record:
+                        continue
+                    await self.handle_dataset_file(dataset, existing_files_map, file_record, source_path)
+
+                except Exception as e:
+                    logger.error(f"Error processing file {source_path}: {e}", exc_info=True)
+                    await self.db.rollback()
+        except Exception as e:
+            await self.db.rollback()
+            logger.error(f"Failed to add files to dataset {dataset_id}: {e}", exc_info=True)
+
+    async def handle_dataset_file(self, dataset: Dataset, existing_files_map: Dict[str, DatasetFiles], file_record: DatasetFiles, source_path: str):
+        target_path = file_record.file_path
+        file_size = file_record.file_size
+        file_name = file_record.file_name
+
+        # Check for a duplicate by target path
+        if target_path in existing_files_map:
+            logger.warning(f"File with name {file_name} already exists in dataset {dataset.id}")
+            dataset_file = existing_files_map[target_path]
+            dataset.size_bytes = dataset.size_bytes - dataset_file.file_size + file_size
+            dataset.updated_at = datetime.now()
+            dataset_file.file_size = file_size
+            dataset_file.updated_at = datetime.now()
+        else:
+            # Add the new record to the database
+            self.db.add(file_record)
+            dataset.file_count += 1
+            dataset.size_bytes += file_record.file_size
+            dataset.updated_at = datetime.now()
+            dataset.status = 'ACTIVE'
+        # Copy the file; for duplicates this overwrites the stale copy on disk
+        logger.info(f"Copying file {source_path} to {target_path}")
+        dst_dir = os.path.dirname(target_path)
+        await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
+        await asyncio.to_thread(shutil.copy2, source_path, target_path)
+        await self.db.commit()
+
+    @staticmethod
+    async def create_new_dataset_file(dataset_dir: str, dataset_id: str, source_path: str) -> Optional[DatasetFiles]:
+        source_path_obj = Path(source_path)
+
+        # Check that the source exists and is a regular file
+        if not source_path_obj.exists() or not source_path_obj.is_file():
+            logger.warning(f"Source file does not exist or is not a file: {source_path}")
+            return None
+        file_name = source_path_obj.name
+        file_extension = os.path.splitext(file_name)[1].lstrip('.').lower()
+        file_size = source_path_obj.stat().st_size
+        target_path = os.path.join(dataset_dir, file_name)
+        file_record = DatasetFiles(
+            id=str(uuid.uuid4()),
+            dataset_id=dataset_id,
+            file_name=file_name,
+            file_type=file_extension or 'other',
+            file_size=file_size,
+            file_path=target_path,
+            upload_time=datetime.now(),
+            last_access_time=datetime.now(),
+            status='ACTIVE',
+            created_at=datetime.now(),
+            updated_at=datetime.now()
+        )
+        return file_record
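
For context when reviewing: a minimal usage sketch of the new `add_files_to_dataset` entry point. The service class name (`DatasetFileService`) and the session-factory import below are assumptions for illustration only; they are not part of this diff, so substitute whatever this module actually exposes.

```python
import asyncio

# Hypothetical wiring, assumed for illustration; only add_files_to_dataset
# and its AsyncSession dependency come from this diff.
from app.db.session import async_session_factory  # assumed session factory
from app.services.dataset import DatasetFileService  # assumed class/module path


async def main() -> None:
    async with async_session_factory() as session:  # yields an AsyncSession
        service = DatasetFileService(db=session)  # assumes the service keeps the session on self.db
        # Copies each source file into the dataset's directory, then inserts a
        # DatasetFiles row (or updates the existing row if the target path exists).
        await service.add_files_to_dataset(
            dataset_id="an-existing-dataset-id",
            source_paths=["/tmp/uploads/a.csv", "/tmp/uploads/b.parquet"],
        )


if __name__ == "__main__":
    asyncio.run(main())
```

One design note: `handle_dataset_file` commits once per file, so a failure mid-batch keeps the files that already succeeded; if all-or-nothing semantics are wanted, the commit would need to move out of the per-file loop.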