fix:eagle3 offline #456
Conversation
Walkthrough

Refactors the speculative decoding examples: adds a disable-tqdm flag, updates the README support matrix, converts TRTLLM hidden-state post-processing to async with validation and cleanup, and strengthens data loading and conversation ID resolution and validation.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Collector as Collector
    participant TRT_dump as TRTLLM Dump File
    participant PostProc as _post_process_trtllm_dumped (async)
    participant Output as Conversation Output File
    note over Collector,TRT_dump: Hidden states collection produces a dump file
    Collector->>TRT_dump: write dump (idx N)
    Collector->>PostProc: await post_process(trt_dump, conv_id)
    alt file missing
        PostProc-->>Collector: return False
    else file present
        PostProc->>PostProc: validate structure == [ { id, hidden_state } ]
        PostProc->>Output: write conversation-specific output
        PostProc->>TRT_dump: unlink original dump
        PostProc-->>Collector: return True
    end
    alt success
        Collector->>Collector: increment num_success and progress
    else failure
        Collector->>Collector: skip increment, log/warn
    end
```
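To make the flow concrete, here is a minimal synchronous sketch of the post-processing step the diagram describes. The dump format (`[{"id": ..., "hidden_state": ...}]`), output layout, and helper name are inferred from the diagram and review notes, so treat the details as assumptions rather than the repository's exact code.

```python
# Minimal sketch of the diagrammed post-processing flow; names and the dump
# format are assumptions based on the diagram, not the repo's implementation.
from pathlib import Path

import torch


def post_process_dump(trtllm_dumped_file: Path, conversation_id: int, output_dir: Path) -> bool:
    # "file missing" branch: signal failure so the caller skips num_success.
    if not trtllm_dumped_file.exists():
        return False
    dumped = torch.load(trtllm_dumped_file)
    # Validate structure == [ { id, hidden_state } ].
    assert isinstance(dumped, list) and len(dumped) == 1, "expected a single-entry list"
    assert "id" in dumped[0] and "hidden_state" in dumped[0], "unexpected dump format"
    # Write the conversation-specific output, then unlink the original dump.
    record = {"conversation_id": conversation_id, "hidden_state": dumped[0]["hidden_state"]}
    torch.save(record, output_dir / f"{conversation_id}.pt")
    trtllm_dumped_file.unlink()
    return True
```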
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #456   +/-   ##
=======================================
  Coverage   73.39%   73.39%
=======================================
  Files         180      180
  Lines       17976    17976
=======================================
+ Hits        13193    13194       +1
+ Misses       4783     4782       -1
```

☔ View full report in Codecov by Sentry.
Force-pushed: 4d4c153 → 885e017
Force-pushed: 885e017 → 70d4eee
Signed-off-by: h-guo18 <[email protected]>
Force-pushed: 70d4eee → 313d3b1
Actionable comments posted: 0
🧹 Nitpick comments (1)
examples/speculative_decoding/eagle_utils.py (1)
316-316: Remove duplicate log message. The log message "Loading input conversations..." appears twice: once at line 256 (in the offline path) and again here at line 316 (in the non-offline path). Consider removing this duplicate since the message at line 256 already covers the loading step.
Apply this diff to remove the duplicate:
```diff
-    print_rank_0("Loading input conversations...")
     dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- examples/speculative_decoding/README.md (1 hunks)
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (3 hunks)
- examples/speculative_decoding/eagle_utils.py (3 hunks)
- examples/speculative_decoding/launch_train.sh (3 hunks)
- examples/speculative_decoding/main.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/speculative_decoding/eagle_utils.py (1)
modelopt/torch/utils/logging.py (1)
- print_rank_0 (92-95)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: linux
🔇 Additional comments (12)
examples/speculative_decoding/README.md (1)
315-315: LGTM! Documentation updated to reflect expanded model support. The support matrix now correctly includes Qwen versions 1.5, 2, 2.5, and 3, aligning with the PR's goal of adding Qwen3 model support.
examples/speculative_decoding/main.py (1)

95-95: LGTM! Clean addition of progress bar control. The new `disable_tqdm` field follows the established pattern for training arguments and addresses the PR objective to prevent CI logging overflow.

examples/speculative_decoding/launch_train.sh (3)

81-84: LGTM! Flag parsing follows established pattern. The `--disable_tqdm` flag is parsed consistently with other command-line arguments in the script.

117-117: LGTM! Default value aligns with main.py. The default value of `False` matches the field default in `TrainingArguments`.

173-173: LGTM! Flag correctly forwarded to training command. The `--disable_tqdm` flag is properly injected into the accelerate launch command.
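Taken together, these main.py and launch_train.sh comments describe simple flag plumbing. A hedged sketch of the dataclass side follows; the field shape and help text are assumptions, not the repository's exact code:

```python
# Sketch only (assumed shape): a training-arguments field that launch_train.sh
# can set via --disable_tqdm; the False default matches the review's notes.
from dataclasses import dataclass, field


@dataclass
class TrainingArguments:
    disable_tqdm: bool = field(
        default=False,
        metadata={"help": "Disable tqdm progress bars (avoids CI log overflow)."},
    )
```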
examples/speculative_decoding/eagle_utils.py (3)

63-65: LGTM! Removed hard-coded role alternation check. This change enables more flexible conversation formats, including system prompts, which aligns with the PR's objective to preserve system prompts during dataset preprocessing.

256-269: LGTM! Directory support for multi-file datasets. The enhanced data loading logic now accepts a directory containing multiple `.jsonl` files, which fulfills the PR objective. The implementation correctly handles both single files and directories with appropriate path-based branching, as sketched below.
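A minimal sketch of that branching, with illustrative names (`load_conversations`, `data_path`) that are not the repository's exact API:

```python
# Illustrative sketch of file-or-directory .jsonl loading as described above.
import json
from pathlib import Path


def load_conversations(data_path: str) -> list[dict]:
    path = Path(data_path)
    # Directory: read every .jsonl file inside; otherwise treat as one file.
    files = sorted(path.glob("*.jsonl")) if path.is_dir() else [path]
    records: list[dict] = []
    for f in files:
        with f.open() as fh:
            records.extend(json.loads(line) for line in fh if line.strip())
    return records
```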
286-293: LGTM! Robust conversation ID resolution with proper validation. The enhanced ID resolution now tries `conversation_id`, then `uuid`, then `id`, and raises a descriptive `ValueError` if none are found. This aligns with the PR's objective to accept the "uuid" field and ensure an ID is present.
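In code, the described resolution order looks roughly like this (the helper name is hypothetical):

```python
# Hypothetical helper mirroring the described conversation_id -> uuid -> id order.
def resolve_conversation_id(entry: dict):
    for key in ("conversation_id", "uuid", "id"):
        if entry.get(key) is not None:
            return entry[key]
    raise ValueError(
        "Conversation entry has no id; expected one of: conversation_id, uuid, id"
    )
```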
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (4)

211-239: LGTM! Robust async post-processing with validation. The refactored `_post_process_trtllm_dumped` function correctly:

- Returns early if the dump file doesn't exist (lines 219-220)
- Validates the expected format with clear assertions (lines 223-230)
- Transforms the data to match HF format conventions
- Cleans up the temporary dump file (line 238)
The async implementation is appropriate for I/O-heavy operations.
247-248: LGTM! Correct async usage with success tracking. The code properly awaits the async post-processing function and converts the boolean result to an integer for the success counter.
255-274: LGTM! Fixes index mismatch for skipped conversations. The key fix here is initializing `idx = 0` (line 255) and incrementing only for valid conversations (line 274). This ensures that the dump file index aligns with the actual number of processed conversations, resolving the sample index mismatch mentioned in the PR objectives when conversations are skipped.
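In sketch form, the pattern this comment endorses looks like the following; every name here is a placeholder for the script's actual validation and dump logic, not its real code:

```python
# Sketch of the index-tracking pattern: skipped conversations do not consume
# an index, so dump-file numbering stays aligned with processed conversations.
def is_valid(conversations: list[dict]) -> bool:
    # Placeholder for the script's real checks (e.g., max_seq_len validation).
    return len(conversations) > 0


def process(conversations: list[dict], dump_index: int) -> None:
    # Placeholder for the hidden-state dump keyed by dump_index.
    print(f"dumping conversation at index {dump_index}")


def run(dataset: list[dict]) -> None:
    idx = 0
    for entry in dataset:
        conversations = entry["conversations"]
        if not is_valid(conversations):
            continue  # skipped entries do not increment idx
        process(conversations, dump_index=idx)
        idx += 1
```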
264-264: LGTM! Full conversation preservation. Removing the `[:256]` slice ensures that complete conversations are processed, with proper length validation occurring downstream at line 268.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (2)
211-239: Blocking I/O operations in async function will block the event loop. The function is declared `async` but performs synchronous blocking I/O operations (`torch.load`, `torch.save`, `Path.unlink`) without using `await`. This means these operations will block the event loop when the function is awaited, defeating the purpose of async concurrency and potentially degrading performance when processing multiple conversations.

Consider one of these solutions:

Option 1 (simpler): Remove `async` and call it synchronously since post-processing is relatively fast:

```diff
-async def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
+def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
     """
     Post-process the TRTLLM dumped file to same format as HF dumped:
     1. Remove id field, replace it with conversation_id
```

Then update the call site:

```diff
-    dump_success = await _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
+    dump_success = _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
```

Option 2 (proper async I/O): Use `asyncio.to_thread` to run blocking I/O in a thread pool:

```python
async def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
    return await asyncio.to_thread(
        _post_process_trtllm_dumped_sync, trtllm_dumped_file, conversation_id
    )


def _post_process_trtllm_dumped_sync(trtllm_dumped_file: str, conversation_id: int):
    # existing implementation
    ...
```
152-153: Replace assertion with proper validation in filter function. Using `assert` for data validation in a filter function is risky. If a conversation lacks both `conversation_id` and `uuid` fields, the assertion will raise an exception that could crash the entire filtering operation. Assertions should be used for invariants, not data validation.

```diff
 def keep_conversation(entry):
     conversation_id = entry.get("conversation_id", entry.get("uuid", None))
-    assert conversation_id is not None, "conversation_id is required"
+    if conversation_id is None:
+        print("Warning: Skipping conversation without conversation_id or uuid")
+        return False
     output_file = args.output_dir / f"{conversation_id}.pt"
     return not output_file.exists()
```
🧹 Nitpick comments (2)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (2)
219-220: Consider logging when the TRTLLM dump file is missing. The early existence check is good defensive programming, but silently returning `False` when the file doesn't exist could make debugging difficult. If TRTLLM generation fails to create the expected file, it would be helpful to log this for troubleshooting.

```diff
     if not trtllm_dumped_file.exists():
+        print(f"Warning: TRTLLM dump file not found: {trtllm_dumped_file}")
         return False
```
238-239: Add error handling for file deletion. The `unlink()` operation could raise exceptions (e.g., permission errors, filesystem issues). While unlikely, a failure here shouldn't necessarily fail the entire post-processing since the data was already successfully saved.

```diff
-    trtllm_dumped_file.unlink()
-    return True
+    try:
+        trtllm_dumped_file.unlink()
+    except OSError as e:
+        print(f"Warning: Failed to delete temporary file {trtllm_dumped_file}: {e}")
+    return True
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- examples/speculative_decoding/README.md (1 hunks)
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (3 hunks)
- examples/speculative_decoding/eagle_utils.py (3 hunks)
- examples/speculative_decoding/launch_train.sh (3 hunks)
- examples/speculative_decoding/main.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- examples/speculative_decoding/README.md
- examples/speculative_decoding/launch_train.sh
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/main.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (3)
241-249: Logic correctly tracks post-processing success. The function properly awaits the asynchronous operations and only increments `num_success` when post-processing succeeds. The file indexing (`idx + 1`) correctly accounts for TRTLLM's 1-based file naming.

Note: The effectiveness of this implementation depends on fixing the blocking I/O issue in `_post_process_trtllm_dumped` (see earlier comment).
255-274: Index tracking fix correctly aligns with TRTLLM dump files. The explicit `idx` counter that only increments for valid conversations (line 274) is the key fix mentioned in the PR objectives. This ensures the dump file index matches the actual conversations processed, preventing index mismatches when some conversations are skipped.

The removal of the `[:256]` slice on line 264 is intentional and allows processing full conversations up to `max_seq_len`.

Optional defensive improvement: Add a safety check for `conversation_id` even though the earlier filter should have removed None values:

```diff
 for entry in dataset:
     conversation_id = entry.get("conversation_id", entry.get("uuid"))
+    if conversation_id is None:
+        num_invalid += 1
+        continue
     conversations = entry["conversations"]
```
264-267: Full conversation processing is correctly implemented. Removing the `[:256]` slice allows the script to process complete conversations up to `max_seq_len`, as intended by the PR. The subsequent length validation (lines 268-270) ensures conversations stay within bounds.
What does this PR do?
Type of change: Bug fix
Overview:
A few minor fixes for eagle3 offline:
- Accept `uuid` as the conversation id field and assert there is a conversation id; (Same as above)

Usage
No change;
Testing
- `Qwen3-30B-A3B` for training, export, and eval;

Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Bug Fixes
Documentation