Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_tags(self) -> List[str]:
tags.append(tag_values)
# 如果 from_name 不为空,添加前缀
if self.from_name:
tags = [f"{self.from_name}@{tag}" for tag in tags]
tags = [{"label": self.from_name, "value": tag} for tag in tags]
return tags


Expand Down
29 changes: 20 additions & 9 deletions runtime/datamate-python/app/module/ratio/service/ratio_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def create_task(
'label': {
"label":item.get("filter_conditions").label.label,
"value":item.get("filter_conditions").label.value,
},
} if item.get("filter_conditions").label else None,
})
)
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
Expand Down Expand Up @@ -147,6 +147,7 @@ async def handle_ratio_relations(relations: list[RatioRelation], session, target
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
)
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
source_paths = set()

added_count = 0
added_size = 0
Expand All @@ -164,10 +165,13 @@ async def handle_ratio_relations(relations: list[RatioRelation], session, target
chosen = random.sample(files, pick_n) if pick_n < len(files) else files

# Copy into target dataset with de-dup by target path
for f in chosen:
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
for file in chosen:
if file.file_path in source_paths:
continue
await RatioTaskService.handle_selected_file(existing_paths, file, session, target_ds)
source_paths.add(file.file_path)
added_count += 1
added_size += int(f.file_size or 0)
added_size += int(file.file_size or 0)

# Periodically flush to avoid huge transactions
await session.flush()
Expand Down Expand Up @@ -286,18 +290,25 @@ def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
return False
try:
# tags could be a list of strings or list of objects with 'name'
tag_names = RatioTaskService.get_all_tags(tags)
return f"{conditions.label.label}@{conditions.label.value}" in tag_names
all_tags = RatioTaskService.get_all_tags(tags)
for tag in all_tags:
if conditions.label.label and tag.get("label") != conditions.label.label:
continue
if conditions.label.value is not None:
return True
if tag.get("value") == conditions.label.value:
return True
return False
except Exception as e:
logger.exception(f"Failed to get tags for {file}", e)
return False

return True

@staticmethod
def get_all_tags(tags) -> set[str]:
def get_all_tags(tags) -> list[dict]:
"""获取所有处理后的标签字符串列表"""
all_tags = set()
all_tags = list()
if not tags:
return all_tags

Expand All @@ -314,5 +325,5 @@ def get_all_tags(tags) -> set[str]:

for file_tag in file_tags:
for tag_data in file_tag.get_tags():
all_tags.add(tag_data)
all_tags.append(tag_data)
return all_tags