86 changes: 82 additions & 4 deletions agentlightning/verl/daemon.py
@@ -72,6 +72,63 @@ def _none_special_token_sequence(ids: List[int]) -> List[int]:
     return is_prefix, (template_mismatch, retoken_mismatch, others_mismatch)


+def _normalize_image_url(url: str) -> str:
+    if url.startswith("file://"):
+        path = url[len("file://") :]
+        try:
+            with open(path, "rb") as handle:
+                data = handle.read()
+            import hashlib
Copilot AI (Jan 29, 2026):
The import statement for hashlib should be placed at the top of the file with other imports, rather than being imported locally within the function. This is inconsistent with Python best practices and the codebase's import patterns.

Author (reply):
@copilot open a new pull request to apply changes based on this feedback
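A minimal sketch of the layout this comment asks for (hypothetical; the PR as written keeps the local imports):

    # Module-level import block at the top of agentlightning/verl/daemon.py,
    # extended with the modules currently imported inside the function.
    import base64
    import hashlib

    def _normalize_image_url(url: str) -> str:
        # Body unchanged, minus the local "import hashlib" / "import base64"
        # statements, which the module-level imports replace.
        ...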


return f"sha1:{hashlib.sha1(data).hexdigest()}"
except Exception:
return path
if url.startswith("data:"):
import base64
import hashlib
Comment on lines +87 to +88
Copilot AI (Jan 29, 2026):
The import statements for base64 and hashlib should be placed at the top of the file with other imports, rather than being imported locally within the function. This is inconsistent with Python best practices and the codebase's import patterns.

+        header, _, payload = url.partition(",")
+        try:
+            if ";base64" in header:
+                data = base64.b64decode(payload)
+            else:
+                data = payload.encode("utf-8")
+            return f"sha1:{hashlib.sha1(data).hexdigest()}"
+        except Exception:
+            return f"{header},{payload[:64]}"
+    return url


+def image_urls_startswith(full_urls: List[str], prefix_urls: List[str]) -> bool:
+    if not prefix_urls:
+        return True
+    if len(full_urls) < len(prefix_urls):
+        return False
+    norm_full = [_normalize_image_url(url) for url in full_urls]
+    norm_prefix = [_normalize_image_url(url) for url in prefix_urls]
+    return norm_full[: len(norm_prefix)] == norm_prefix


+def log_image_mismatch_detail(
+    full_urls: List[str],
+    prefix_urls: List[str],
+    global_steps: int,
+    rollout_id: str,
+    turn_id: int,
+    log_dir: str | None = None,
+):
+    if log_dir is None:
+        return
+    os.makedirs(log_dir, exist_ok=True)
+    with open(os.path.join(log_dir, "image_mismatch.log"), "a+") as f:
+        print(
+            "-" * 10 + f" Global Steps: {global_steps}, Rollout ID: {rollout_id}, Turn ID: {turn_id} " + "-" * 10,
+            file=f,
+        )
+        print([_normalize_image_url(u) for u in full_urls], file=f)
+        print([_normalize_image_url(u) for u in prefix_urls], file=f)


 def log_mismatch_detail(
     diagnostic: Tuple[bool, bool, bool],
     full_ids: List[int],
@@ -913,11 +970,11 @@ def get_train_data_batch(
                 image_grid_thw_list.append(self._get_image_grid_thw(image_urls))

         elif self.trace_aggregator.get("level", "transition") == "trajectory":
-            assert not self._use_mrope, "M-RoPE is not supported in trajectory level yet."

             response_mask_list: List[List[int]] = []
             unmerged_count: int = 0
             template_mismatch_count, retoken_mismatch_count, others_mismatch_count = 0, 0, 0
+            image_mismatch_count = 0
             response_per_turn_list: List[int] = []

             for rollout_id, sample_info in finished_id_to_sample_info.items():
@@ -926,15 +983,28 @@ def get_train_data_batch(
                 # Identify which turns can be merged based on token ids prefix matching
                 current_merged_trace_idx: List[int] = []
                 current_context: List[int] = []
+                current_image_urls: List[str] = []
                 for turn_index, trace in enumerate(sample_info["trace_list"]):
                     response_per_turn_list.append(len(trace["response_ids"]))
-                    is_prefix, diagnostic = ids_startswith(
+                    token_prefix_ok, diagnostic = ids_startswith(
                         trace["prompt_ids"] + trace["response_ids"],
                         current_context,
                         self.tokenizer,
                         self.trace_aggregator.get("debug", False),
                     )
-                    if not is_prefix and self.trace_aggregator.get("debug", False) == True:
+                    image_prefix_ok = image_urls_startswith(trace.get("image_urls", []), current_image_urls)
+                    if not image_prefix_ok:
+                        image_mismatch_count += 1
+                        if self.trace_aggregator.get("debug", False) == True:
Copilot AI (Jan 29, 2026):
The condition self.trace_aggregator.get("debug", False) == True is redundant. Since the get method returns a boolean, the explicit comparison to True is unnecessary and less Pythonic. The condition should be simplified to self.trace_aggregator.get("debug", False).
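As a standalone illustration of the suggested simplification (the dict here is a hypothetical stand-in for self.trace_aggregator):

    trace_aggregator = {"debug": True}

    # Before: redundant comparison against True.
    if trace_aggregator.get("debug", False) == True:
        print("debug logging on")

    # After: the truthiness check alone is sufficient and more Pythonic.
    if trace_aggregator.get("debug", False):
        print("debug logging on")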
+                            log_image_mismatch_detail(
+                                trace.get("image_urls", []),
+                                current_image_urls,
+                                global_steps,
+                                rollout_id,
+                                turn_index,
+                                self.trace_aggregator.get("mismatch_log_dir", None),
+                            )
+                    if not token_prefix_ok and self.trace_aggregator.get("debug", False) == True:
Copilot AI (Jan 29, 2026):
The condition self.trace_aggregator.get("debug", False) == True is redundant. Since the get method returns a boolean, the explicit comparison to True is unnecessary and less Pythonic. The condition should be simplified to self.trace_aggregator.get("debug", False). This is consistent with the same pattern that should be fixed on line 998.
+                        template_mismatch_count += diagnostic[0]
+                        retoken_mismatch_count += diagnostic[1]
+                        others_mismatch_count += diagnostic[2]
Expand All @@ -948,13 +1018,15 @@ def get_train_data_batch(
self.trace_aggregator.get("mismatch_log_dir", None),
)

if is_prefix:
if token_prefix_ok and image_prefix_ok:
current_context = trace["prompt_ids"] + trace["response_ids"]
current_merged_trace_idx.append(turn_index)
current_image_urls = trace.get("image_urls", [])
else:
merged_trace_idx.append(current_merged_trace_idx)
current_merged_trace_idx = [turn_index]
current_context = trace["prompt_ids"] + trace["response_ids"]
current_image_urls = trace.get("image_urls", [])

if current_merged_trace_idx not in merged_trace_idx:
merged_trace_idx.append(current_merged_trace_idx)
@@ -1019,6 +1091,10 @@ def get_train_data_batch(
                     response_mask_list.append(one_response_mask)
                     data_id_list.append(sample_info["data_id"])
                     rollout_id_list.append(rollout_id)
+                    if self._use_mrope:
+                        last_trace = sample_info["trace_list"][current_merged_trace_idx[-1]]
+                        image_urls = last_trace.get("image_urls", [])
+                        image_grid_thw_list.append(self._get_image_grid_thw(image_urls))
                     # turn_index_list.append(current_merged_trace_idx)
         else:
             raise ValueError(f"Unknown trace_aggregator level: {self.trace_aggregator.get('level')}")
@@ -1115,9 +1191,11 @@ def get_train_data_batch(
"training/template_mismatch_triplets": template_mismatch_count, # type: ignore
"training/retoken_mismatch_triplets": retoken_mismatch_count, # type: ignore
"training/others_mismatch_triplets": others_mismatch_count, # type: ignore
"training/image_mismatch_triplets": image_mismatch_count, # type: ignore
"training/template_mismatch_ratio": template_mismatch_count / len(response_per_turn_list), # type: ignore
"training/retoken_mismatch_ratio": retoken_mismatch_count / len(response_per_turn_list), # type: ignore
"training/others_mismatch_ratio": others_mismatch_count / len(response_per_turn_list), # type: ignore
"training/image_mismatch_ratio": image_mismatch_count / len(response_per_turn_list), # type: ignore
}
if self.trace_aggregator.get("level", "transition") == "trajectory"
and self.trace_aggregator.get("debug", False)
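To make the merging rule concrete, here is a standalone usage sketch of the image-prefix check added in this diff (it inlines simplified copies of the helpers; the example.com URL is hypothetical):

    import base64
    import hashlib
    from typing import List

    def _normalize_image_url(url: str) -> str:
        # Simplified copy of the helper above: data: URLs are fingerprinted
        # by the SHA-1 of their payload, so identical images compare equal
        # regardless of how the URL string was assembled.
        if url.startswith("data:"):
            header, _, payload = url.partition(",")
            data = base64.b64decode(payload) if ";base64" in header else payload.encode("utf-8")
            return f"sha1:{hashlib.sha1(data).hexdigest()}"
        return url

    def image_urls_startswith(full_urls: List[str], prefix_urls: List[str]) -> bool:
        # Same semantics as the helper above: an empty prefix matches anything,
        # and a prefix longer than the full list never matches.
        if not prefix_urls:
            return True
        if len(full_urls) < len(prefix_urls):
            return False
        norm_prefix = [_normalize_image_url(u) for u in prefix_urls]
        return [_normalize_image_url(u) for u in full_urls[: len(norm_prefix)]] == norm_prefix

    png = base64.b64encode(b"fake-image-bytes").decode()
    turn1 = [f"data:image/png;base64,{png}"]
    turn2 = turn1 + ["https://example.com/second.png"]

    print(image_urls_startswith(turn2, turn1))  # True  -> turn can merge into the running context
    print(image_urls_startswith(turn1, turn2))  # False -> counted as an image mismatch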