Skip to content

Commit 5bd6024

Browse files
authored
feature: allow creating a dataset while creating a collection task (#228)
1 parent f541e1c commit 5bd6024

File tree

5 files changed

+160
-11
lines changed

5 files changed

+160
-11
lines changed

runtime/datamate-python/app/module/collection/interface/collection.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sqlalchemy.ext.asyncio import AsyncSession
1010

1111
from app.core.logging import get_logger
12+
from app.db.models import Dataset
1213
from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
1314
from app.db.session import get_db
1415
from app.module.collection.client.datax_client import DataxClient
@@ -40,9 +41,22 @@ async def create_task(
4041
DataxClient.generate_datx_config(request.config, template, f"/dataset/local/{task_id}")
4142
task = convert_for_create(request, task_id)
4243
task.template_name = template.name
44+
dataset = None
45+
46+
if request.dataset_name:
47+
target_dataset_id = uuid.uuid4()
48+
dataset = Dataset(
49+
id=str(target_dataset_id),
50+
name=request.dataset_name,
51+
description="",
52+
dataset_type=request.dataset_type.name,
53+
status="DRAFT",
54+
path=f"/dataset/{target_dataset_id}",
55+
)
56+
db.add(dataset)
4357

4458
task_service = CollectionTaskService(db)
45-
task = await task_service.create_task(task)
59+
task = await task_service.create_task(task, dataset)
4660

4761
task = await db.execute(select(CollectionTask).where(CollectionTask.id == task.id))
4862
task = task.scalar_one_or_none()

runtime/datamate-python/app/module/collection/schema/collection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from pydantic.alias_generators import to_camel
99

1010
from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
11+
from app.module.dataset.schema import DatasetTypeResponse
12+
from app.module.dataset.schema.dataset import DatasetType
1113
from app.module.shared.schema import TaskStatus
1214

1315

@@ -52,6 +54,8 @@ class CollectionTaskCreate(BaseModel):
5254
schedule_expression: Optional[str] = Field(None, description="调度表达式(cron)")
5355
config: CollectionConfig = Field(..., description="任务配置")
5456
template_id: str = Field(..., description="模板ID")
57+
dataset_name: Optional[str] = Field(None, description="数据集名称")
58+
dataset_type: Optional[DatasetType] = Field(DatasetType.TEXT, description="数据集类型")
5559

5660
model_config = ConfigDict(
5761
alias_generator=to_camel,

runtime/datamate-python/app/module/collection/service/collection.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import asyncio
22
from dataclasses import dataclass
3+
from pathlib import Path
34
from typing import Any, Optional
45

56
from sqlalchemy import select
67
from sqlalchemy.ext.asyncio import AsyncSession
78

89
from app.core.logging import get_logger
10+
from app.db.models import Dataset
911
from app.db.models.data_collection import CollectionTask, CollectionTemplate
1012
from app.db.session import AsyncSessionLocal
1113
from app.module.collection.client.datax_client import DataxClient
1214
from app.module.collection.schema.collection import SyncMode, create_execute_record
15+
from app.module.dataset.service.service import Service
1316
from app.module.shared.schema import TaskStatus
1417

1518
logger = get_logger(__name__)
@@ -38,18 +41,18 @@ class CollectionTaskService:
3841
    def __init__(self, db: AsyncSession):
        """Create the service bound to a request-scoped async DB session.

        All task persistence performed by this service goes through
        this session; the caller owns its lifecycle.
        """
        self.db = db
4043

41-
async def create_task(self, task: CollectionTask) -> CollectionTask:
44+
async def create_task(self, task: CollectionTask, dataset: Dataset) -> CollectionTask:
4245
self.db.add(task)
4346

4447
# If it's a one-time task, execute it immediately
4548
if task.sync_mode == SyncMode.ONCE:
4649
task.status = TaskStatus.RUNNING.name
4750
await self.db.commit()
48-
asyncio.create_task(CollectionTaskService.run_async(task.id))
51+
asyncio.create_task(CollectionTaskService.run_async(task.id, dataset.id if dataset else None))
4952
return task
5053

5154
@staticmethod
52-
async def run_async(task_id: str):
55+
async def run_async(task_id: str, dataset_id: str = None):
5356
logger.info(f"start to execute task {task_id}")
5457
async with AsyncSessionLocal() as session:
5558
task = await session.execute(select(CollectionTask).where(CollectionTask.id == task_id))
@@ -69,3 +72,12 @@ async def run_async(task_id: str):
6972
DataxClient(execution=task_execution, task=task, template=template).run_datax_job
7073
)
7174
await session.commit()
75+
if dataset_id:
76+
dataset_service = Service(db=session)
77+
source_paths = []
78+
target_path = Path(task.target_path)
79+
if target_path.exists() and target_path.is_dir():
80+
for file_path in target_path.rglob('*'):
81+
if file_path.is_file():
82+
source_paths.append(str(file_path.absolute()))
83+
await dataset_service.add_files_to_dataset(dataset_id=dataset_id, source_paths=source_paths)

runtime/datamate-python/app/module/dataset/schema/dataset.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
from enum import Enum
2+
13
from pydantic import BaseModel, Field
24
from typing import List, Optional, Dict, Any
35
from datetime import datetime
46

7+
class DatasetType(Enum):
    """Closed set of media types a dataset can hold.

    Member names and values are identical strings so that the enum can be
    round-tripped through APIs that send either form.
    """

    TEXT = "TEXT"
    IMAGE = "IMAGE"
    AUDIO = "AUDIO"
    VIDEO = "VIDEO"
12+
513
class DatasetTypeResponse(BaseModel):
614
"""数据集类型响应模型"""
715
code: str = Field(..., description="类型编码")
@@ -22,7 +30,7 @@ class DatasetResponse(BaseModel):
2230
createdAt: Optional[datetime] = Field(None, description="创建时间")
2331
updatedAt: Optional[datetime] = Field(None, description="更新时间")
2432
createdBy: Optional[str] = Field(None, description="创建者")
25-
33+
2634
# 为了向后兼容,添加一个属性方法返回类型对象
2735
@property
2836
def type(self) -> DatasetTypeResponse:
@@ -33,4 +41,4 @@ def type(self) -> DatasetTypeResponse:
3341
description=None,
3442
supportedFormats=[],
3543
icon=None
36-
)
44+
)

runtime/datamate-python/app/module/dataset/service/service.py

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import math
2+
import os
3+
import shutil
4+
import asyncio
5+
import uuid
6+
from datetime import datetime
7+
from pathlib import Path
8+
from typing import Optional, List, Dict, Any, Coroutine
9+
10+
from sqlalchemy import func, and_
211
from sqlalchemy.ext.asyncio import AsyncSession
312
from sqlalchemy.future import select
4-
from sqlalchemy import func
5-
from typing import Optional, List, Dict, Any
6-
from datetime import datetime
713

8-
from app.core.config import settings
914
from app.core.logging import get_logger
1015
from app.db.models import Dataset, DatasetFiles
11-
1216
from ..schema import DatasetResponse, PagedDatasetFileResponse, DatasetFileResponse
1317

1418
logger = get_logger(__name__)
@@ -263,3 +267,110 @@ async def update_file_tags_partial(
263267
logger.error(f"Failed to update tags for file {file_id}: {e}")
264268
await self.db.rollback()
265269
return False, str(e), None
270+
271+
@staticmethod
272+
async def _get_or_create_dataset_directory(dataset: Dataset) -> str:
273+
"""Get or create dataset directory"""
274+
dataset_dir = dataset.path
275+
os.makedirs(dataset_dir, exist_ok=True)
276+
return dataset_dir
277+
278+
async def add_files_to_dataset(self, dataset_id: str, source_paths: List[str]):
279+
"""
280+
Copy files to dataset directory and create corresponding database records
281+
282+
Args:
283+
dataset_id: ID of the dataset
284+
source_paths: List of source file paths to copy
285+
286+
Returns:
287+
List of created dataset file records
288+
"""
289+
logger.info(f"Starting to add files to dataset {dataset_id}")
290+
291+
try:
292+
# Get dataset and existing files
293+
dataset = await self.db.get(Dataset, dataset_id)
294+
if not dataset:
295+
logger.error(f"Dataset not found: {dataset_id}")
296+
return
297+
298+
# Get existing files to check for duplicates
299+
result = await self.db.execute(
300+
select(DatasetFiles).where(DatasetFiles.dataset_id == dataset_id)
301+
)
302+
existing_files_map = dict()
303+
for dataset_file in result.scalars().all():
304+
existing_files_map.__setitem__(dataset_file.file_path, dataset_file)
305+
306+
# Get or create dataset directory
307+
dataset_dir = await self._get_or_create_dataset_directory(dataset)
308+
309+
# Process each source file
310+
for source_path in source_paths:
311+
try:
312+
file_record = await self.create_new_dataset_file(dataset_dir, dataset_id, source_path)
313+
if not file_record:
314+
continue
315+
await self.handle_dataset_file(dataset, existing_files_map, file_record, source_path)
316+
317+
except Exception as e:
318+
logger.error(f"Error processing file {source_path}: {str(e)}", e)
319+
await self.db.rollback()
320+
except Exception as e:
321+
await self.db.rollback()
322+
logger.error(f"Failed to add files to dataset {dataset_id}: {str(e)}", exc_info=True)
323+
324+
    async def handle_dataset_file(self, dataset: Dataset, existing_files_map: dict[Any, Any], file_record: DatasetFiles, source_path: str):
        """Copy one file into the dataset directory and update bookkeeping.

        If a row already exists for the same target path, its size and
        timestamp are refreshed and the dataset's byte count is adjusted by
        the size delta; otherwise the new row is inserted and the dataset's
        file count and size grow. The file on disk is (over)written in both
        cases, and the session is committed once per file.
        """
        target_path = file_record.file_path
        file_size = file_record.file_size
        file_name = file_record.file_name

        # Duplicate detection is keyed on the full target path (not just the
        # bare file name): sources with the same name collapse to the same
        # target, so a later copy overwrites the earlier one.
        if target_path in existing_files_map:
            logger.warning(f"File with name {file_name} already exists in dataset {dataset.id}")
            dataset_file = existing_files_map.get(target_path)
            # Adjust the dataset total by the difference between old and new size.
            dataset.size_bytes = dataset.size_bytes - dataset_file.file_size + file_size
            dataset.updated_at = datetime.now()
            dataset_file.file_size = file_size
            dataset_file.updated_at = datetime.now()
        else:
            # Add to database
            self.db.add(file_record)
            dataset.file_count = dataset.file_count + 1
            dataset.size_bytes = dataset.size_bytes + file_record.file_size
            dataset.updated_at = datetime.now()
            # NOTE(review): first successful import appears to activate the
            # dataset (created as DRAFT) — confirm intended for duplicates too.
            dataset.status = 'ACTIVE'
        # Copy file — blocking FS calls are offloaded so the loop stays responsive.
        logger.info(f"copy file {source_path} to {target_path}")
        dst_dir = os.path.dirname(target_path)
        await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
        await asyncio.to_thread(shutil.copy2, source_path, target_path)
        await self.db.commit()
350+
351+
@staticmethod
352+
async def create_new_dataset_file(dataset_dir: str, dataset_id: str, source_path: str) -> DatasetFiles | None:
353+
source_path_obj = Path(source_path)
354+
355+
# Check if source exists and is a file
356+
if not source_path_obj.exists() or not source_path_obj.is_file():
357+
logger.warning(f"Source file does not exist or is not a file: {source_path}")
358+
return None
359+
file_name = source_path_obj.name
360+
file_extension = os.path.splitext(file_name)[1].lstrip('.').lower()
361+
file_size = source_path_obj.stat().st_size
362+
target_path = os.path.join(dataset_dir, file_name)
363+
file_record = DatasetFiles(
364+
id=str(uuid.uuid4()),
365+
dataset_id=dataset_id,
366+
file_name=file_name,
367+
file_type=file_extension or 'other',
368+
file_size=file_size,
369+
file_path=target_path,
370+
upload_time=datetime.now(),
371+
last_access_time=datetime.now(),
372+
status='ACTIVE',
373+
created_at=datetime.now(),
374+
updated_at=datetime.now()
375+
)
376+
return file_record

0 commit comments

Comments
 (0)