Skip to content

Commit 8f52995

Browse files
authored
Fix ratio (#162)
* fix: fixed the issue where an error would be reported when only setting the proportioning quantity when creating a proportioning task * fix: prevent adding the same file multiple times * fix: implement a more flexible matching strategy, allowing only the tag name to be configured for matching
1 parent bb8641b commit 8f52995

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

runtime/datamate-python/app/module/dataset/schema/dataset_file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_tags(self) -> List[str]:
4646
tags.append(tag_values)
4747
# 如果 from_name 不为空,添加前缀
4848
if self.from_name:
49-
tags = [f"{self.from_name}@{tag}" for tag in tags]
49+
tags = [{"label": self.from_name, "value": tag} for tag in tags]
5050
return tags
5151

5252

runtime/datamate-python/app/module/ratio/service/ratio_task.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ async def create_task(
6161
'label': {
6262
"label":item.get("filter_conditions").label.label,
6363
"value":item.get("filter_conditions").label.value,
64-
},
64+
} if item.get("filter_conditions").label else None,
6565
})
6666
)
6767
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
@@ -147,6 +147,7 @@ async def handle_ratio_relations(relations: list[RatioRelation], session, target
147147
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
148148
)
149149
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
150+
source_paths = set()
150151

151152
added_count = 0
152153
added_size = 0
@@ -164,10 +165,13 @@ async def handle_ratio_relations(relations: list[RatioRelation], session, target
164165
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
165166

166167
# Copy into target dataset with de-dup by target path
167-
for f in chosen:
168-
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
168+
for file in chosen:
169+
if file.file_path in source_paths:
170+
continue
171+
await RatioTaskService.handle_selected_file(existing_paths, file, session, target_ds)
172+
source_paths.add(file.file_path)
169173
added_count += 1
170-
added_size += int(f.file_size or 0)
174+
added_size += int(file.file_size or 0)
171175

172176
# Periodically flush to avoid huge transactions
173177
await session.flush()
@@ -286,18 +290,25 @@ def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
286290
return False
287291
try:
288292
# tags could be a list of strings or list of objects with 'name'
289-
tag_names = RatioTaskService.get_all_tags(tags)
290-
return f"{conditions.label.label}@{conditions.label.value}" in tag_names
293+
all_tags = RatioTaskService.get_all_tags(tags)
294+
for tag in all_tags:
295+
if conditions.label.label and tag.get("label") != conditions.label.label:
296+
continue
297+
if conditions.label.value is not None:
298+
return True
299+
if tag.get("value") == conditions.label.value:
300+
return True
301+
return False
291302
except Exception as e:
292303
logger.exception(f"Failed to get tags for {file}", e)
293304
return False
294305

295306
return True
296307

297308
@staticmethod
298-
def get_all_tags(tags) -> set[str]:
309+
def get_all_tags(tags) -> list[dict]:
299310
"""获取所有处理后的标签字符串列表"""
300-
all_tags = set()
311+
all_tags = list()
301312
if not tags:
302313
return all_tags
303314

@@ -314,5 +325,5 @@ def get_all_tags(tags) -> set[str]:
314325

315326
for file_tag in file_tags:
316327
for tag_data in file_tag.get_tags():
317-
all_tags.add(tag_data)
328+
all_tags.append(tag_data)
318329
return all_tags

0 commit comments

Comments
 (0)