-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Support trajectory aggregation for M-RoPE multimodal models, and add multimodal prefix checks for trajectory merge #469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,6 +72,63 @@ def _none_special_token_sequence(ids: List[int]) -> List[int]: | |
| return is_prefix, (template_mismatch, retoken_mismatch, others_mismatch) | ||
|
|
||
|
|
||
| def _normalize_image_url(url: str) -> str: | ||
| if url.startswith("file://"): | ||
| path = url[len("file://") :] | ||
| try: | ||
| with open(path, "rb") as handle: | ||
| data = handle.read() | ||
| import hashlib | ||
|
|
||
| return f"sha1:{hashlib.sha1(data).hexdigest()}" | ||
| except Exception: | ||
| return path | ||
| if url.startswith("data:"): | ||
| import base64 | ||
| import hashlib | ||
|
Comment on lines
+87
to
+88
|
||
|
|
||
| header, _, payload = url.partition(",") | ||
| try: | ||
| if ";base64" in header: | ||
| data = base64.b64decode(payload) | ||
| else: | ||
| data = payload.encode("utf-8") | ||
| return f"sha1:{hashlib.sha1(data).hexdigest()}" | ||
| except Exception: | ||
| return f"{header},{payload[:64]}" | ||
| return url | ||
|
|
||
|
|
||
def image_urls_startswith(full_urls: List[str], prefix_urls: List[str]) -> bool:
    """Return True when ``prefix_urls`` is a prefix of ``full_urls``.

    URLs are compared after normalization, so content-identical images
    referenced through different URL forms still match.
    """
    if not prefix_urls:
        return True
    if len(full_urls) < len(prefix_urls):
        return False
    # zip stops at the shorter (prefix) list, pairing each prefix URL
    # with its positional counterpart in the full list.
    return all(
        _normalize_image_url(candidate) == _normalize_image_url(expected)
        for candidate, expected in zip(full_urls, prefix_urls)
    )
|
|
||
|
|
||
| def log_image_mismatch_detail( | ||
| full_urls: List[str], | ||
| prefix_urls: List[str], | ||
| global_steps: int, | ||
| rollout_id: str, | ||
| turn_id: int, | ||
| log_dir: str | None = None, | ||
| ): | ||
| if log_dir is None: | ||
| return | ||
| os.makedirs(log_dir, exist_ok=True) | ||
| with open(os.path.join(log_dir, "image_mismatch.log"), "a+") as f: | ||
| print( | ||
| "-" * 10 + f" Global Steps: {global_steps}, Rollout ID: {rollout_id}, Turn ID: {turn_id} " + "-" * 10, | ||
| file=f, | ||
| ) | ||
| print([_normalize_image_url(u) for u in full_urls], file=f) | ||
| print([_normalize_image_url(u) for u in prefix_urls], file=f) | ||
|
|
||
|
|
||
| def log_mismatch_detail( | ||
| diagnostic: Tuple[bool, bool, bool], | ||
| full_ids: List[int], | ||
|
|
@@ -913,11 +970,11 @@ def get_train_data_batch( | |
| image_grid_thw_list.append(self._get_image_grid_thw(image_urls)) | ||
|
|
||
| elif self.trace_aggregator.get("level", "transition") == "trajectory": | ||
| assert not self._use_mrope, "M-RoPE is not supported in trajectory level yet." | ||
|
|
||
| response_mask_list: List[List[int]] = [] | ||
| unmerged_count: int = 0 | ||
| template_mismatch_count, retoken_mismatch_count, others_mismatch_count = 0, 0, 0 | ||
| image_mismatch_count = 0 | ||
| response_per_turn_list: List[int] = [] | ||
|
|
||
| for rollout_id, sample_info in finished_id_to_sample_info.items(): | ||
|
|
@@ -926,15 +983,28 @@ def get_train_data_batch( | |
| # Identify which turns can be merged based on token ids prefix matching | ||
| current_merged_trace_idx: List[int] = [] | ||
| current_context: List[int] = [] | ||
| current_image_urls: List[str] = [] | ||
| for turn_index, trace in enumerate(sample_info["trace_list"]): | ||
| response_per_turn_list.append(len(trace["response_ids"])) | ||
| is_prefix, diagnostic = ids_startswith( | ||
| token_prefix_ok, diagnostic = ids_startswith( | ||
| trace["prompt_ids"] + trace["response_ids"], | ||
| current_context, | ||
| self.tokenizer, | ||
| self.trace_aggregator.get("debug", False), | ||
| ) | ||
| if not is_prefix and self.trace_aggregator.get("debug", False) == True: | ||
| image_prefix_ok = image_urls_startswith(trace.get("image_urls", []), current_image_urls) | ||
| if not image_prefix_ok: | ||
| image_mismatch_count += 1 | ||
| if self.trace_aggregator.get("debug", False) == True: | ||
|
||
| log_image_mismatch_detail( | ||
| trace.get("image_urls", []), | ||
| current_image_urls, | ||
| global_steps, | ||
| rollout_id, | ||
| turn_index, | ||
| self.trace_aggregator.get("mismatch_log_dir", None), | ||
| ) | ||
| if not token_prefix_ok and self.trace_aggregator.get("debug", False) == True: | ||
|
||
| template_mismatch_count += diagnostic[0] | ||
| retoken_mismatch_count += diagnostic[1] | ||
| others_mismatch_count += diagnostic[2] | ||
|
|
@@ -948,13 +1018,15 @@ def get_train_data_batch( | |
| self.trace_aggregator.get("mismatch_log_dir", None), | ||
| ) | ||
|
|
||
| if is_prefix: | ||
| if token_prefix_ok and image_prefix_ok: | ||
| current_context = trace["prompt_ids"] + trace["response_ids"] | ||
| current_merged_trace_idx.append(turn_index) | ||
| current_image_urls = trace.get("image_urls", []) | ||
| else: | ||
| merged_trace_idx.append(current_merged_trace_idx) | ||
| current_merged_trace_idx = [turn_index] | ||
| current_context = trace["prompt_ids"] + trace["response_ids"] | ||
| current_image_urls = trace.get("image_urls", []) | ||
|
|
||
| if current_merged_trace_idx not in merged_trace_idx: | ||
| merged_trace_idx.append(current_merged_trace_idx) | ||
|
|
@@ -1019,6 +1091,10 @@ def get_train_data_batch( | |
| response_mask_list.append(one_response_mask) | ||
| data_id_list.append(sample_info["data_id"]) | ||
| rollout_id_list.append(rollout_id) | ||
| if self._use_mrope: | ||
| last_trace = sample_info["trace_list"][current_merged_trace_idx[-1]] | ||
| image_urls = last_trace.get("image_urls", []) | ||
| image_grid_thw_list.append(self._get_image_grid_thw(image_urls)) | ||
| # turn_index_list.append(current_merged_trace_idx) | ||
| else: | ||
| raise ValueError(f"Unknown trace_aggregator level: {self.trace_aggregator.get('level')}") | ||
|
|
@@ -1115,9 +1191,11 @@ def get_train_data_batch( | |
| "training/template_mismatch_triplets": template_mismatch_count, # type: ignore | ||
| "training/retoken_mismatch_triplets": retoken_mismatch_count, # type: ignore | ||
| "training/others_mismatch_triplets": others_mismatch_count, # type: ignore | ||
| "training/image_mismatch_triplets": image_mismatch_count, # type: ignore | ||
| "training/template_mismatch_ratio": template_mismatch_count / len(response_per_turn_list), # type: ignore | ||
| "training/retoken_mismatch_ratio": retoken_mismatch_count / len(response_per_turn_list), # type: ignore | ||
| "training/others_mismatch_ratio": others_mismatch_count / len(response_per_turn_list), # type: ignore | ||
| "training/image_mismatch_ratio": image_mismatch_count / len(response_per_turn_list), # type: ignore | ||
| } | ||
| if self.trace_aggregator.get("level", "transition") == "trajectory" | ||
| and self.trace_aggregator.get("debug", False) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The import statement for hashlib should be placed at the top of the file with other imports, rather than being imported locally within the function. This is inconsistent with Python best practices and the codebase's import patterns.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot open a new pull request to apply changes based on this feedback