diff --git a/runtime/datamate-python/app/module/dataset/schema/dataset_file.py b/runtime/datamate-python/app/module/dataset/schema/dataset_file.py index f31bba84..e387f9a0 100644 --- a/runtime/datamate-python/app/module/dataset/schema/dataset_file.py +++ b/runtime/datamate-python/app/module/dataset/schema/dataset_file.py @@ -46,7 +46,7 @@ def get_tags(self) -> List[str]: tags.append(tag_values) # 如果 from_name 不为空,添加前缀 if self.from_name: - tags = [f"{self.from_name}@{tag}" for tag in tags] + tags = [{"label": self.from_name, "value": tag} for tag in tags] return tags diff --git a/runtime/datamate-python/app/module/ratio/service/ratio_task.py b/runtime/datamate-python/app/module/ratio/service/ratio_task.py index 943d23b0..be8d001d 100644 --- a/runtime/datamate-python/app/module/ratio/service/ratio_task.py +++ b/runtime/datamate-python/app/module/ratio/service/ratio_task.py @@ -61,7 +61,7 @@ async def create_task( 'label': { "label":item.get("filter_conditions").label.label, "value":item.get("filter_conditions").label.value, - }, + } if item.get("filter_conditions").label else None, }) ) logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}") @@ -147,6 +147,7 @@ async def handle_ratio_relations(relations: list[RatioRelation], session, target select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id) ) existing_paths = set(p for p in existing_path_rows.scalars().all() if p) + source_paths = set() added_count = 0 added_size = 0 @@ -164,10 +165,13 @@ async def handle_ratio_relations(relations: list[RatioRelation], session, target chosen = random.sample(files, pick_n) if pick_n < len(files) else files # Copy into target dataset with de-dup by target path - for f in chosen: - await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds) + for file in chosen: + if file.file_path in source_paths: + continue + await RatioTaskService.handle_selected_file(existing_paths, file, session, target_ds) + source_paths.add(file.file_path) added_count += 1 - added_size += int(f.file_size or 0) + added_size += int(file.file_size or 0) # Periodically flush to avoid huge transactions await session.flush() @@ -286,8 +290,15 @@ def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool: return False try: # tags could be a list of strings or list of objects with 'name' - tag_names = RatioTaskService.get_all_tags(tags) - return f"{conditions.label.label}@{conditions.label.value}" in tag_names + all_tags = RatioTaskService.get_all_tags(tags) + for tag in all_tags: + if conditions.label.label and tag.get("label") != conditions.label.label: + continue + if conditions.label.value is not None: + return True + if tag.get("value") == conditions.label.value: + return True + return False except Exception as e: logger.exception(f"Failed to get tags for {file}", e) return False @@ -295,9 +306,9 @@ def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool: return True @staticmethod - def get_all_tags(tags) -> set[str]: + def get_all_tags(tags) -> list[dict]: """获取所有处理后的标签字符串列表""" - all_tags = set() + all_tags = list() if not tags: return all_tags @@ -314,5 +325,5 @@ def get_all_tags(tags) -> set[str]: for file_tag in file_tags: for tag_data in file_tag.get_tags(): - all_tags.add(tag_data) + all_tags.append(tag_data) return all_tags