diff --git a/runtime/datamate-python/app/module/collection/interface/collection.py b/runtime/datamate-python/app/module/collection/interface/collection.py
index cd596fd7..5ae92ee3 100644
--- a/runtime/datamate-python/app/module/collection/interface/collection.py
+++ b/runtime/datamate-python/app/module/collection/interface/collection.py
@@ -9,6 +9,7 @@
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.core.logging import get_logger
+from app.db.models import Dataset
 from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
 from app.db.session import get_db
 from app.module.collection.client.datax_client import DataxClient
@@ -40,9 +41,22 @@ async def create_task(
     DataxClient.generate_datx_config(request.config, template, f"/dataset/local/{task_id}")
     task = convert_for_create(request, task_id)
     task.template_name = template.name
+    dataset = None
+
+    if request.dataset_name:
+        target_dataset_id = uuid.uuid4()
+        dataset = Dataset(
+            id=str(target_dataset_id),
+            name=request.dataset_name,
+            description="",
+            dataset_type=request.dataset_type.name if request.dataset_type else "TEXT",  # guard an explicit null
+            status="DRAFT",
+            path=f"/dataset/{target_dataset_id}",
+        )
+        db.add(dataset)
 
     task_service = CollectionTaskService(db)
-    task = await task_service.create_task(task)
+    task = await task_service.create_task(task, dataset)
 
     task = await db.execute(select(CollectionTask).where(CollectionTask.id == task.id))
     task = task.scalar_one_or_none()
diff --git a/runtime/datamate-python/app/module/collection/schema/collection.py b/runtime/datamate-python/app/module/collection/schema/collection.py
index 3781ff88..be7faf76 100644
--- a/runtime/datamate-python/app/module/collection/schema/collection.py
+++ b/runtime/datamate-python/app/module/collection/schema/collection.py
@@ -8,6 +8,8 @@
 from pydantic.alias_generators import to_camel
 
 from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
+from app.module.dataset.schema import DatasetTypeResponse
+from app.module.dataset.schema.dataset import DatasetType
 from app.module.shared.schema import TaskStatus
 
 
@@ -52,6 +54,8 @@ class CollectionTaskCreate(BaseModel):
     schedule_expression: Optional[str] = Field(None, description="Schedule expression (cron)")
     config: CollectionConfig = Field(..., description="Task configuration")
     template_id: str = Field(..., description="Template ID")
+    dataset_name: Optional[str] = Field(None, description="Dataset name")
+    dataset_type: Optional[DatasetType] = Field(DatasetType.TEXT, description="Dataset type")
 
     model_config = ConfigDict(
         alias_generator=to_camel,
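The two new CollectionTaskCreate fields are exposed through the model's to_camel alias generator, so clients send them as datasetName and datasetType. A standalone sketch of that behavior follows; it mirrors the two new fields on a hypothetical lookalike model rather than importing the real CollectionTaskCreate, whose remaining fields and full ConfigDict are not shown in this diff:

    from enum import Enum
    from typing import Optional

    from pydantic import BaseModel, ConfigDict, Field
    from pydantic.alias_generators import to_camel

    class DatasetType(Enum):
        TEXT = "TEXT"
        IMAGE = "IMAGE"
        AUDIO = "AUDIO"
        VIDEO = "VIDEO"

    class TaskCreateSketch(BaseModel):
        # Mirrors the two fields added to CollectionTaskCreate
        dataset_name: Optional[str] = Field(None, description="Dataset name")
        dataset_type: Optional[DatasetType] = Field(DatasetType.TEXT, description="Dataset type")
        model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

    # camelCase keys are accepted via the alias generator
    req = TaskCreateSketch.model_validate({"datasetName": "crawl-output", "datasetType": "TEXT"})
    assert req.dataset_type is DatasetType.TEXT

    # Omitting datasetType falls back to TEXT, but an explicit null yields None
    assert TaskCreateSketch.model_validate({"datasetName": "x"}).dataset_type is DatasetType.TEXT
    assert TaskCreateSketch.model_validate({"datasetName": "x", "datasetType": None}).dataset_type is None

The last case is why the interface guards request.dataset_type before reading .name: the TEXT default does not protect against a client sending an explicit null.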
diff --git a/runtime/datamate-python/app/module/collection/service/collection.py b/runtime/datamate-python/app/module/collection/service/collection.py
index a04df140..e2e63128 100644
--- a/runtime/datamate-python/app/module/collection/service/collection.py
+++ b/runtime/datamate-python/app/module/collection/service/collection.py
@@ -1,15 +1,18 @@
 import asyncio
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Any, Optional
 
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.core.logging import get_logger
+from app.db.models import Dataset
 from app.db.models.data_collection import CollectionTask, CollectionTemplate
 from app.db.session import AsyncSessionLocal
 from app.module.collection.client.datax_client import DataxClient
 from app.module.collection.schema.collection import SyncMode, create_execute_record
+from app.module.dataset.service.service import Service
 from app.module.shared.schema import TaskStatus
 
 logger = get_logger(__name__)
@@ -38,18 +41,18 @@ class CollectionTaskService:
     def __init__(self, db: AsyncSession):
         self.db = db
 
-    async def create_task(self, task: CollectionTask) -> CollectionTask:
+    async def create_task(self, task: CollectionTask, dataset: Optional[Dataset]) -> CollectionTask:
         self.db.add(task)
         # If it's a one-time task, execute it immediately
         if task.sync_mode == SyncMode.ONCE:
             task.status = TaskStatus.RUNNING.name
             await self.db.commit()
-            asyncio.create_task(CollectionTaskService.run_async(task.id))
+            asyncio.create_task(CollectionTaskService.run_async(task.id, dataset.id if dataset else None))
         return task
 
     @staticmethod
-    async def run_async(task_id: str):
+    async def run_async(task_id: str, dataset_id: Optional[str] = None):
         logger.info(f"start to execute task {task_id}")
         async with AsyncSessionLocal() as session:
             task = await session.execute(select(CollectionTask).where(CollectionTask.id == task_id))
@@ -69,3 +72,13 @@ async def run_async(task_id: str):
                 DataxClient(execution=task_execution, task=task, template=template).run_datax_job
             )
             await session.commit()
+            if dataset_id:
+                dataset_service = Service(db=session)
+                source_paths = []
+                # Collect every regular file the job wrote under the target directory
+                target_path = Path(task.target_path)
+                if target_path.exists() and target_path.is_dir():
+                    for file_path in target_path.rglob('*'):
+                        if file_path.is_file():
+                            source_paths.append(str(file_path.absolute()))
+                await dataset_service.add_files_to_dataset(dataset_id=dataset_id, source_paths=source_paths)
diff --git a/runtime/datamate-python/app/module/dataset/schema/dataset.py b/runtime/datamate-python/app/module/dataset/schema/dataset.py
index 8c35e56a..84334d8c 100644
--- a/runtime/datamate-python/app/module/dataset/schema/dataset.py
+++ b/runtime/datamate-python/app/module/dataset/schema/dataset.py
@@ -1,7 +1,15 @@
+from enum import Enum
+
 from pydantic import BaseModel, Field
 from typing import List, Optional, Dict, Any
 from datetime import datetime
 
+class DatasetType(Enum):
+    TEXT = "TEXT"
+    IMAGE = "IMAGE"
+    AUDIO = "AUDIO"
+    VIDEO = "VIDEO"
+
 class DatasetTypeResponse(BaseModel):
     """Dataset type response model"""
     code: str = Field(..., description="Type code")
@@ -22,7 +30,7 @@ class DatasetResponse(BaseModel):
     createdAt: Optional[datetime] = Field(None, description="Creation time")
     updatedAt: Optional[datetime] = Field(None, description="Update time")
     createdBy: Optional[str] = Field(None, description="Creator")
-    
+
     # For backward compatibility, expose a property that returns the type object
     @property
     def type(self) -> DatasetTypeResponse:
@@ -33,4 +41,4 @@ def type(self) -> DatasetTypeResponse:
         description=None,
         supportedFormats=[],
         icon=None
-    )
\ No newline at end of file
+    )
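run_async feeds the dataset only after the DataX job has committed: it walks task.target_path recursively and hands every regular file to Service.add_files_to_dataset. A self-contained sketch of that scan (the directory path is illustrative):

    import asyncio
    from pathlib import Path

    def collect_files(target_dir: str) -> list[str]:
        # Mirror the scan in run_async: every regular file under the target
        # directory, as an absolute path; a missing or non-directory target
        # yields an empty list.
        target_path = Path(target_dir)
        if not (target_path.exists() and target_path.is_dir()):
            return []
        return [str(p.absolute()) for p in target_path.rglob('*') if p.is_file()]

    async def main() -> None:
        # e.g. the DataX job wrote its output under /dataset/local/<task_id>
        print(collect_files("/dataset/local/demo-task"))

    asyncio.run(main())

One caveat worth noting: create_task discards the Task object returned by asyncio.create_task, and the event loop holds only a weak reference to running tasks, so a long-running sync could in principle be garbage-collected mid-flight; keeping the task in a module-level set until completion is the usual remedy.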
diff --git a/runtime/datamate-python/app/module/dataset/service/service.py b/runtime/datamate-python/app/module/dataset/service/service.py
index 41e2c318..8ccbc6f6 100644
--- a/runtime/datamate-python/app/module/dataset/service/service.py
+++ b/runtime/datamate-python/app/module/dataset/service/service.py
@@ -1,14 +1,18 @@
 import math
+import os
+import shutil
+import asyncio
+import uuid
+from datetime import datetime
+from pathlib import Path
+from typing import Optional, List, Dict, Any
+
+from sqlalchemy import func
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
-from sqlalchemy import func
-from typing import Optional, List, Dict, Any
-from datetime import datetime
-from app.core.config import settings
 from app.core.logging import get_logger
 from app.db.models import Dataset, DatasetFiles
-
 from ..schema import DatasetResponse, PagedDatasetFileResponse, DatasetFileResponse
 
 
 logger = get_logger(__name__)
@@ -263,3 +267,110 @@ async def update_file_tags_partial(
             logger.error(f"Failed to update tags for file {file_id}: {e}")
             await self.db.rollback()
             return False, str(e), None
+
+    @staticmethod
+    async def _get_or_create_dataset_directory(dataset: Dataset) -> str:
+        """Get or create the dataset directory."""
+        dataset_dir = dataset.path
+        os.makedirs(dataset_dir, exist_ok=True)
+        return dataset_dir
+
+    async def add_files_to_dataset(self, dataset_id: str, source_paths: List[str]):
+        """
+        Copy files into the dataset directory and create the corresponding database records.
+
+        Args:
+            dataset_id: ID of the dataset
+            source_paths: List of source file paths to copy
+        """
+        logger.info(f"Starting to add files to dataset {dataset_id}")
+
+        try:
+            # Get dataset and existing files
+            dataset = await self.db.get(Dataset, dataset_id)
+            if not dataset:
+                logger.error(f"Dataset not found: {dataset_id}")
+                return
+
+            # Map existing file paths to their records to detect duplicates
+            result = await self.db.execute(
+                select(DatasetFiles).where(DatasetFiles.dataset_id == dataset_id)
+            )
+            existing_files_map = {
+                dataset_file.file_path: dataset_file
+                for dataset_file in result.scalars().all()
+            }
+
+            # Get or create dataset directory
+            dataset_dir = await self._get_or_create_dataset_directory(dataset)
+
+            # Process each source file
+            for source_path in source_paths:
+                try:
+                    file_record = await self.create_new_dataset_file(dataset_dir, dataset_id, source_path)
+                    if not file_record:
+                        continue
+                    await self.handle_dataset_file(dataset, existing_files_map, file_record, source_path)
+                except Exception as e:
+                    logger.error(f"Error processing file {source_path}: {e}", exc_info=True)
+                    await self.db.rollback()
+        except Exception as e:
+            await self.db.rollback()
+            logger.error(f"Failed to add files to dataset {dataset_id}: {e}", exc_info=True)
+
+    async def handle_dataset_file(self, dataset: Dataset, existing_files_map: Dict[str, DatasetFiles], file_record: DatasetFiles, source_path: str):
+        target_path = file_record.file_path
+        file_size = file_record.file_size
+        file_name = file_record.file_name
+
+        # Check for a duplicate by target path
+        if target_path in existing_files_map:
+            logger.warning(f"File with name {file_name} already exists in dataset {dataset.id}")
+            dataset_file = existing_files_map[target_path]
+            dataset.size_bytes = dataset.size_bytes - dataset_file.file_size + file_size
+            dataset.updated_at = datetime.now()
+            dataset_file.file_size = file_size
+            dataset_file.updated_at = datetime.now()
+        else:
+            # New file: add the record and update dataset counters
+            self.db.add(file_record)
+            dataset.file_count = dataset.file_count + 1
+            dataset.size_bytes = dataset.size_bytes + file_record.file_size
+            dataset.updated_at = datetime.now()
+        dataset.status = 'ACTIVE'
+        # Copy the file in a worker thread so the event loop is not blocked
+        logger.info(f"copy file {source_path} to {target_path}")
+        dst_dir = os.path.dirname(target_path)
+        await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
+        await asyncio.to_thread(shutil.copy2, source_path, target_path)
+        await self.db.commit()
+
+    @staticmethod
+    async def create_new_dataset_file(dataset_dir: str, dataset_id: str, source_path: str) -> Optional[DatasetFiles]:
+        source_path_obj = Path(source_path)
+
+        # Check that the source exists and is a regular file
+        if not source_path_obj.exists() or not source_path_obj.is_file():
+            logger.warning(f"Source file does not exist or is not a file: {source_path}")
+            return None
+
+        file_name = source_path_obj.name
+        file_extension = os.path.splitext(file_name)[1].lstrip('.').lower()
+        file_size = source_path_obj.stat().st_size
+        target_path = os.path.join(dataset_dir, file_name)
+
+        file_record = DatasetFiles(
+            id=str(uuid.uuid4()),
+            dataset_id=dataset_id,
+            file_name=file_name,
+            file_type=file_extension or 'other',
+            file_size=file_size,
+            file_path=target_path,
+            upload_time=datetime.now(),
+            last_access_time=datetime.now(),
+            status='ACTIVE',
+            created_at=datetime.now(),
+            updated_at=datetime.now()
+        )
+        return file_record
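The new Service.add_files_to_dataset can also be driven outside the collection flow, e.g. from a backfill script. A minimal sketch, assuming the app package is importable and the database is configured; the dataset id and file paths are illustrative:

    import asyncio

    from app.db.session import AsyncSessionLocal
    from app.module.dataset.service.service import Service

    async def backfill(dataset_id: str, paths: list[str]) -> None:
        async with AsyncSessionLocal() as session:
            # Duplicate target paths update the existing record in place;
            # missing source files are skipped with a warning.
            await Service(db=session).add_files_to_dataset(
                dataset_id=dataset_id, source_paths=paths
            )

    if __name__ == "__main__":
        asyncio.run(backfill("demo-dataset-id", ["/tmp/export/a.txt", "/tmp/export/b.csv"]))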