@@ -9,6 +9,7 @@
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models import Dataset
from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
from app.db.session import get_db
from app.module.collection.client.datax_client import DataxClient
@@ -40,9 +41,22 @@ async def create_task(
DataxClient.generate_datx_config(request.config, template, f"/dataset/local/{task_id}")
task = convert_for_create(request, task_id)
task.template_name = template.name
dataset = None

if request.dataset_name:
target_dataset_id = uuid.uuid4()
dataset = Dataset(
id=str(target_dataset_id),
name=request.dataset_name,
description="",
dataset_type=request.dataset_type.name,
status="DRAFT",
path=f"/dataset/{target_dataset_id}",
)
db.add(dataset)

task_service = CollectionTaskService(db)
task = await task_service.create_task(task)
task = await task_service.create_task(task, dataset)

task = await db.execute(select(CollectionTask).where(CollectionTask.id == task.id))
task = task.scalar_one_or_none()
@@ -8,6 +8,8 @@
from pydantic.alias_generators import to_camel

from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
from app.module.dataset.schema import DatasetTypeResponse
from app.module.dataset.schema.dataset import DatasetType
from app.module.shared.schema import TaskStatus


@@ -52,6 +54,8 @@ class CollectionTaskCreate(BaseModel):
schedule_expression: Optional[str] = Field(None, description="Schedule expression (cron)")
config: CollectionConfig = Field(..., description="Task configuration")
template_id: str = Field(..., description="Template ID")
dataset_name: Optional[str] = Field(None, description="Dataset name")
dataset_type: Optional[DatasetType] = Field(DatasetType.TEXT, description="Dataset type")

model_config = ConfigDict(
alias_generator=to_camel,
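
For reference, a minimal sketch (not part of this diff) of a request body exercising the two new fields; the model's camelCase alias generator means they arrive as datasetName and datasetType, and the remaining values here are placeholders rather than a complete CollectionTaskCreate payload.

# Illustrative payload only; other required CollectionTaskCreate fields are omitted.
payload = {
    "scheduleExpression": None,         # one-time task
    "config": {},                       # CollectionConfig body, omitted for brevity
    "templateId": "tmpl-123",           # hypothetical template id
    "datasetName": "crawled-articles",  # triggers Dataset creation in the router
    "datasetType": "TEXT",              # one of TEXT / IMAGE / AUDIO / VIDEO
}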
@@ -1,15 +1,18 @@
import asyncio
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models import Dataset
from app.db.models.data_collection import CollectionTask, CollectionTemplate
from app.db.session import AsyncSessionLocal
from app.module.collection.client.datax_client import DataxClient
from app.module.collection.schema.collection import SyncMode, create_execute_record
from app.module.dataset.service.service import Service
from app.module.shared.schema import TaskStatus

logger = get_logger(__name__)
@@ -38,18 +41,18 @@ class CollectionTaskService:
def __init__(self, db: AsyncSession):
self.db = db

async def create_task(self, task: CollectionTask) -> CollectionTask:
async def create_task(self, task: CollectionTask, dataset: Optional[Dataset]) -> CollectionTask:
self.db.add(task)

# If it's a one-time task, execute it immediately
if task.sync_mode == SyncMode.ONCE:
task.status = TaskStatus.RUNNING.name
await self.db.commit()
asyncio.create_task(CollectionTaskService.run_async(task.id))
asyncio.create_task(CollectionTaskService.run_async(task.id, dataset.id if dataset else None))
return task

@staticmethod
async def run_async(task_id: str):
async def run_async(task_id: str, dataset_id: Optional[str] = None):
logger.info(f"start to execute task {task_id}")
async with AsyncSessionLocal() as session:
task = await session.execute(select(CollectionTask).where(CollectionTask.id == task_id))
@@ -69,3 +72,12 @@ async def run_async(task_id: str):
DataxClient(execution=task_execution, task=task, template=template).run_datax_job
)
await session.commit()
if dataset_id:
dataset_service = Service(db=session)
source_paths = []
target_path = Path(task.target_path)
if target_path.exists() and target_path.is_dir():
for file_path in target_path.rglob('*'):
if file_path.is_file():
source_paths.append(str(file_path.absolute()))
await dataset_service.add_files_to_dataset(dataset_id=dataset_id, source_paths=source_paths)
12 changes: 10 additions & 2 deletions runtime/datamate-python/app/module/dataset/schema/dataset.py
@@ -1,7 +1,15 @@
from enum import Enum

from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from datetime import datetime

class DatasetType(Enum):
TEXT = "TEXT"
IMAGE = "IMAGE"
AUDIO = "AUDIO"
VIDEO = "VIDEO"

class DatasetTypeResponse(BaseModel):
"""数据集类型响应模型"""
code: str = Field(..., description="类型编码")
@@ -22,7 +30,7 @@ class DatasetResponse(BaseModel):
createdAt: Optional[datetime] = Field(None, description="Created at")
updatedAt: Optional[datetime] = Field(None, description="Updated at")
createdBy: Optional[str] = Field(None, description="Created by")

# For backward compatibility, add a property that returns the type object
@property
def type(self) -> DatasetTypeResponse:
@@ -33,4 +41,4 @@ def type(self) -> DatasetTypeResponse:
description=None,
supportedFormats=[],
icon=None
)
)
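
A quick illustrative check (not from the diff) of how the new enum is used by the collection router above: the router persists DatasetType.name on the Dataset row, and Pydantic resolves an incoming plain string by value.

from app.module.dataset.schema.dataset import DatasetType

assert DatasetType.TEXT.name == "TEXT"            # the string stored in Dataset.dataset_type
assert DatasetType("IMAGE") is DatasetType.IMAGE  # how an incoming "IMAGE" value resolves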
121 changes: 116 additions & 5 deletions runtime/datamate-python/app/module/dataset/service/service.py
@@ -1,14 +1,18 @@
import math
import os
import shutil
import asyncio
import uuid
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any, Coroutine

from sqlalchemy import func, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func
from typing import Optional, List, Dict, Any
from datetime import datetime

from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import Dataset, DatasetFiles

from ..schema import DatasetResponse, PagedDatasetFileResponse, DatasetFileResponse

logger = get_logger(__name__)
@@ -263,3 +267,110 @@ async def update_file_tags_partial(
logger.error(f"Failed to update tags for file {file_id}: {e}")
await self.db.rollback()
return False, str(e), None

@staticmethod
async def _get_or_create_dataset_directory(dataset: Dataset) -> str:
"""Get or create dataset directory"""
dataset_dir = dataset.path
os.makedirs(dataset_dir, exist_ok=True)
return dataset_dir

async def add_files_to_dataset(self, dataset_id: str, source_paths: List[str]):
"""
Copy files to dataset directory and create corresponding database records

Args:
dataset_id: ID of the dataset
source_paths: List of source file paths to copy

Returns:
None; file records are created and committed as each source path is processed
"""
logger.info(f"Starting to add files to dataset {dataset_id}")

try:
# Get dataset and existing files
dataset = await self.db.get(Dataset, dataset_id)
if not dataset:
logger.error(f"Dataset not found: {dataset_id}")
return

# Get existing files to check for duplicates
result = await self.db.execute(
select(DatasetFiles).where(DatasetFiles.dataset_id == dataset_id)
)
existing_files_map = dict()
for dataset_file in result.scalars().all():
existing_files_map[dataset_file.file_path] = dataset_file

# Get or create dataset directory
dataset_dir = await self._get_or_create_dataset_directory(dataset)

# Process each source file
for source_path in source_paths:
try:
file_record = await self.create_new_dataset_file(dataset_dir, dataset_id, source_path)
if not file_record:
continue
await self.handle_dataset_file(dataset, existing_files_map, file_record, source_path)

except Exception as e:
logger.error(f"Error processing file {source_path}: {str(e)}", e)
await self.db.rollback()
except Exception as e:
await self.db.rollback()
logger.error(f"Failed to add files to dataset {dataset_id}: {str(e)}", exc_info=True)

async def handle_dataset_file(self, dataset, existing_files_map: dict[Any, Any], file_record: DatasetFiles, source_path: str):
target_path = file_record.file_path
file_size = file_record.file_size
file_name = file_record.file_name

# Check for duplicate by filename
if target_path in existing_files_map:
logger.warning(f"File with name {file_name} already exists in dataset {dataset.id}")
dataset_file = existing_files_map.get(target_path)
dataset.size_bytes = dataset.size_bytes - dataset_file.file_size + file_size
dataset.updated_at = datetime.now()
dataset_file.file_size = file_size
dataset_file.updated_at = datetime.now()
else:
# Add to database
self.db.add(file_record)
dataset.file_count = dataset.file_count + 1
dataset.size_bytes = dataset.size_bytes + file_record.file_size
dataset.updated_at = datetime.now()
dataset.status = 'ACTIVE'
# Copy file
logger.info(f"copy file {source_path} to {target_path}")
dst_dir = os.path.dirname(target_path)
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
await asyncio.to_thread(shutil.copy2, source_path, target_path)
await self.db.commit()
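
To make the bookkeeping above concrete, a worked example with illustrative numbers (not from the diff): a duplicate target_path only adjusts size_bytes by the size delta, while a new path also increments file_count.

# Illustrative arithmetic mirroring the two branches of handle_dataset_file.
size_bytes, file_count = 1_000, 3

# Duplicate target_path: the old record was 100 bytes, the fresh copy is 120 bytes.
size_bytes = size_bytes - 100 + 120   # -> 1_020; file_count stays 3

# New target_path: a 200-byte file adds a record.
size_bytes += 200                     # -> 1_220
file_count += 1                       # -> 4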

@staticmethod
async def create_new_dataset_file(dataset_dir: str, dataset_id: str, source_path: str) -> DatasetFiles | None:
source_path_obj = Path(source_path)

# Check if source exists and is a file
if not source_path_obj.exists() or not source_path_obj.is_file():
logger.warning(f"Source file does not exist or is not a file: {source_path}")
return None
file_name = source_path_obj.name
file_extension = os.path.splitext(file_name)[1].lstrip('.').lower()
file_size = source_path_obj.stat().st_size
target_path = os.path.join(dataset_dir, file_name)
file_record = DatasetFiles(
id=str(uuid.uuid4()),
dataset_id=dataset_id,
file_name=file_name,
file_type=file_extension or 'other',
file_size=file_size,
file_path=target_path,
upload_time=datetime.now(),
last_access_time=datetime.now(),
status='ACTIVE',
created_at=datetime.now(),
updated_at=datetime.now()
)
return file_record
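
One consequence of the target-path construction above, shown with hypothetical paths (not from the diff): target_path joins dataset_dir with the bare file name, so collected files from different subdirectories that share a basename resolve to the same destination; the later copy overwrites the earlier one on disk, and only paths already recorded in existing_files_map take the duplicate branch in handle_dataset_file.

import os

dataset_dir = "/dataset/example"   # hypothetical dataset directory
sources = [
    "/dataset/local/task-1/a/report.csv",
    "/dataset/local/task-1/b/report.csv",
]

targets = [os.path.join(dataset_dir, os.path.basename(p)) for p in sources]
assert targets[0] == targets[1]    # both map to /dataset/example/report.csv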