Data Inspection Tool #2904

Open

Helw150 wants to merge 3 commits into main from will/inspect-debug-data

Conversation

Helw150 (Member) commented Feb 20, 2026

  • Add scripts/training/inspect_data.py — a CLI tool to inspect decoded training data at a given step on a Ray cluster, useful for debugging data pipelines, mixture weights, and masking
  # Dump a single step's batch as JSONL for agents
  uv run scripts/training/inspect_data.py experiments/references/canary_train.py --step 100 --cluster us-central2

  # Per-step summary stats across multiple steps for agents
  uv run scripts/training/inspect_data.py experiments/references/canary_train.py --steps 0,100,500 --cluster us-central2 --summary

  # Interactive TUI browser for people
  uv run scripts/training/inspect_data.py experiments/references/canary_train.py --step 100 --cluster us-central2 --tui


claude bot commented Feb 20, 2026

Claude finished @Helw150's task in 2m 1s


Review: Data Inspection Tool

  • Read changed file and gather context
  • Verify internal API usage against levanter/marin
  • Post review

This PR adds a CLI tool (scripts/debug/inspect_data.py) that decodes and dumps training batches at a given step, supporting JSONL output, summary stats, and an interactive TUI. The internal API usage against MixtureDataset, BatchSchedule, and TrainLmOnPodConfig is correct.

Specification check: This PR is 587 lines with no linked specification or design doc. Per docs/recipes/agent-coding.md, a specification is expected for PRs of this size.

Bugs

  • Broken Ray job submission path (scripts/debug/inspect_data.py:87): The entrypoint is hardcoded as scripts/training/inspect_data.py, but the file actually lives at scripts/debug/inspect_data.py. When --cluster is used (the primary use case), the submitted Ray job will fail because the entrypoint path doesn't exist. The docstring usage examples (lines 10-17) have the same stale path.

  • TUI Page Down scroll calculation is inconsistent (scripts/debug/inspect_data.py:410-412): Page Down estimates the total line count as len(text.split("\n")) * 2, while draw() computes the true wrapped line count via textwrap.wrap, so Page Down can overshoot or undershoot the actual content length. Page Down/Up should reuse the same wrapping logic as draw(), or cache the wrapped line count.

Guidelines

  • The code accesses several internal/private methods of MixtureDataset (_get_block, _get_stage_for_block). These are implementation details that could change without notice. Consider whether MixtureDataset should expose a small public API surface for inspection use cases, or at minimum add a comment acknowledging the coupling.
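For the second option, a minimal sketch of what such an acknowledgment could look like at the call site (the comment wording is only a suggestion; the surrounding names come from the loop quoted further down):

# NOTE: MixtureDataset has no public API for per-index inspection, so this tool
# relies on the private helpers _get_block / _get_stage_for_block. If those
# internals change, this script must be updated in lockstep.
block = dataset._get_block(block_id)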


config = _resolve_cluster_config(cluster)

parts = ["python", "scripts/training/inspect_data.py", experiment]

Bug: This path doesn't match the actual file location. The file is at scripts/debug/inspect_data.py, so the Ray job will fail to find the entrypoint.

Suggested change
parts = ["python", "scripts/training/inspect_data.py", experiment]
parts = ["python", "scripts/debug/inspect_data.py", experiment]

Comment on lines 410 to 412
text = examples[doc_idx]["text"]
total = len(text.split("\n")) * 2
scroll_offset = min(scroll_offset + body_h, max(0, total))

This rough estimate (newlines * 2) diverges from the actual wrapping logic in draw() which uses textwrap.wrap(line, w - 1). Page Down can overshoot or undershoot the real content. Consider computing the wrapped line count from the same textwrap logic used in draw, or caching it.
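A minimal sketch of that approach, reusing the names from the snippet above (body_h, scroll_offset, doc_idx, and the w - 1 wrap width); the helper is hypothetical and the exact clamp is an assumption:

import textwrap

def wrapped_line_count(text: str, width: int) -> int:
    # Mirror draw(): wrap each logical line; an empty line still occupies one row.
    return sum(max(1, len(textwrap.wrap(line, width))) for line in text.split("\n"))

# Page Down: advance by one screen, clamped so the last screenful stays visible.
total = wrapped_line_count(examples[doc_idx]["text"], w - 1)
scroll_offset = min(scroll_offset + body_h, max(0, total - body_h))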

@XenonMolecule left a comment


Overall this is great!!! And I LOVE the interactive viewer. I had a few errors when I ran the non-interactive mode on my data, so I just wanted to flag the changes I had to make to get it to work, but all around it was a super useful tool!


# When submitted as a Ray job, the script runs on the cluster without --cluster.
# Detect this via RAY_JOB_ID which Ray sets automatically for submitted jobs.
on_cluster_node = os.environ.get("RAY_JOB_ID") is not None


Is this always set when we run on the cluster? When I launched with

uv run scripts/debug/inspect_data.py experiments/exp_qwen3_0_6b_rephraser_sft.py \
    --step 736 --cluster us-central1 -o step_736.jsonl

this line was causing me trouble. I fixed it by moving to

on_cluster_node = os.environ.get("RAY_JOB_ID") is not None or os.environ.get("MARIN_PREFIX") is not None

for idx in indices:
    block_id = idx // dataset.block_size
    index_within_block = idx % dataset.block_size
    block = dataset._get_block(block_id)


Maybe this should be:

blocking_wait(dataset._get_block(block_id)) instead of dataset._get_block(block_id)
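With that change applied, the quoted loop would read roughly as follows (assuming blocking_wait is in scope and that _get_block returns an awaitable/future):

for idx in indices:
    block_id = idx // dataset.block_size
    index_within_block = idx % dataset.block_size
    # Resolve _get_block's result before indexing into the block
    block = blocking_wait(dataset._get_block(block_id))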

for i, (ex, src) in enumerate(zip(examples, sources, strict=True)):
    tokens = ex.tokens.tolist()
    lw = ex.loss_weight
    pct_masked = float((lw == 0).sum()) / len(lw) * 100


tokens = ex.tokens.array.tolist()
lw = ex.loss_weight.array
pct_masked = float((lw == 0).sum()) / lw.size * 100

At least when I ran, LmExample.tokens and LmExample.loss_weight are Haliax NamedArrays, not raw JAX/numpy arrays. I needed .array to unwrap to the underlying JAX array before calling .tolist(). And JAX arrays don't support len() so I used .size instead.

out.close()

if output:
    click.echo(f"Wrote {total_examples} examples to {output}")


Consider:

gcs_output = os.path.join(step_output_path, "debug", output)
fs = gcsfs.GCSFileSystem()
fs.put(output, gcs_output)
click.echo(f"Wrote {total_examples} examples to {gcs_output}")
click.echo(f"Download with: gsutil cp {gcs_output} .")

The output file was previously written to the Ray worker's ephemeral local disk and lost after the job exited. This uploads to GCS under {step_output_path}/debug/, co-located with the training run's artifacts and collision-free via Marin's {name}-{hash} path convention.
