add end_process in deepseek ptq #317
base: main
Conversation
Walkthrough: Adds a new helper function, end_process(), to examples/deepseek/ptq.py that tears down the torch.distributed process group after the PTQ flow completes.
Sequence Diagram(s)
```mermaid
sequenceDiagram
autonumber
actor User
participant PTQ as "PTQ Script\n(examples/deepseek/ptq.py)"
participant Worker
participant Dist as "torch.distributed"
User->>PTQ: run PTQ flow
PTQ->>Worker: perform calibration / quantization
Worker-->>PTQ: save quantization results
rect rgba(224,240,255,0.3)
note right of PTQ: call end_process()
PTQ->>PTQ: read WORLD_SIZE/RANK/LOCAL_RANK
alt world_size > 1 and dist.is_initialized()
PTQ->>Dist: destroy_process_group()
Dist-->>PTQ: resources cleaned
else single-process or not initialized
note over PTQ: no-op
end
end
```
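For reference, the helper this PR introduces can be reconstructed from the review diffs below; here is a minimal, self-contained sketch (the imports and docstring are my additions, and the unused variables mirror the submitted code):

```python
import os

import torch.distributed as dist


def end_process():
    """Tear down the default torch.distributed process group after PTQ."""
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))  # unused; flagged in the review below
    local_rank = int(os.getenv("LOCAL_RANK", "0"))  # unused as well
    if world_size > 1:
        # Only multi-process runs create a default process group to destroy.
        if dist.is_initialized():
            dist.destroy_process_group()
```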
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks: ❌ 1 failed check (warning), ✅ 2 passed.
Actionable comments posted: 1
🧹 Nitpick comments (1)
examples/deepseek/ptq.py (1)
Line 376: Ensure cleanup runs on exceptions: call end_process() from finally. If any step before the current call raises, the process group will leak.
Apply:
```diff
-    args = parser.parse_args()
-    model = load_deepseek_model(args.config, args.model_path, args.batch_size)
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.model_path, trust_remote_code=args.trust_remote_code
-    )
-    model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size)
-    save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache)
-    end_process()
+    args = parser.parse_args()
+    try:
+        model = load_deepseek_model(args.config, args.model_path, args.batch_size)
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.model_path, trust_remote_code=args.trust_remote_code
+        )
+        model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size)
+        save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache)
+    finally:
+        end_process()
```
Force-pushed from 1538644 to 1207c2c.
Signed-off-by: bruce.xu <[email protected]>
Force-pushed from 1207c2c to ee2a1c7.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
examples/deepseek/ptq.py (3)
Lines 293-295: KeyError risk for LOCAL_RANK; use getenv fallback. In non-distributed runs LOCAL_RANK may be unset.
```diff
- if int(os.environ["LOCAL_RANK"]) == 0:
+ if int(os.getenv("LOCAL_RANK", "0")) == 0:
```
Lines 270-272: Guard unconditional dist.barrier() calls; they will raise if the default process group isn't initialized. Wrap each call with a guard: if dist.is_available() and dist.is_initialized(): ... (preserve any group argument, e.g. dist.barrier(group) inside the guard).
Occurrences found (update all): modelopt/torch/export/postprocess.py:567,575; modelopt/torch/export/distribute.py:149,173,258,295; examples/cnn_qat/torchvision_qat.py:176; examples/deepseek/ptq.py:271,307
```diff
- dist.barrier()
+ if dist.is_available() and dist.is_initialized():
+     dist.barrier()
```
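Since the same unguarded call appears across several files, one option is to centralize the guard in a small helper; safe_barrier below is a hypothetical name, not an existing function in the repo:

```python
import torch.distributed as dist


def safe_barrier(group=None):
    """Synchronize ranks only when a process group is actually running."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier(group=group)  # preserve the optional group argument
```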
Lines 299-327: Guard torch.distributed collectives with dist.is_initialized(). dist.barrier(), dist.get_world_size(), and dist.all_gather_object() in examples/deepseek/ptq.py (around lines 299-327) are unguarded and will raise when torch.distributed isn't initialized; wrap them with if dist.is_available() and dist.is_initialized(): ..., else skip the collectives and have rank 0 write outputs directly (e.g., world_size=1 / all_quant_configs=[quant_config]), as sketched below.
♻️ Duplicate comments (1)
examples/deepseek/ptq.py (1)
Lines 232-239: Tighten distributed teardown: guard + remove unused vars. Early-return when not distributed and guard with dist.is_available(); drop unused rank/local_rank.
```diff
 def end_process():
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-    rank = int(os.getenv("RANK", "0"))
-    local_rank = int(os.getenv("LOCAL_RANK", "0"))
-    if world_size > 1:
-        if dist.is_initialized():
-            dist.destroy_process_group()
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    if world_size <= 1:
+        return
+    if dist.is_available() and dist.is_initialized():
+        dist.destroy_process_group()
```
🧹 Nitpick comments (1)
examples/deepseek/ptq.py (1)
Lines 369-377: Ensure cleanup runs on exceptions (wrap in try/finally). This guarantees end_process() executes even if PTQ or saving raises.
```diff
     args = parser.parse_args()
-    model = load_deepseek_model(args.config, args.model_path, args.batch_size)
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.model_path, trust_remote_code=args.trust_remote_code
-    )
-    model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size)
-    save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache)
-    end_process()
+    try:
+        model = load_deepseek_model(args.config, args.model_path, args.batch_size)
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.model_path, trust_remote_code=args.trust_remote_code
+        )
+        model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size)
+        save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache)
+    finally:
+        end_process()
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/deepseek/ptq.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/deepseek/ptq.py (1)
modelopt/torch/utils/distributed.py (1): world_size (204-206)
@sugunav14 Hello, please help review this.
What does this PR do?
When we run PTQ, we see a warning about the missing dist.destroy_process_group() call, so I added it.
Additional Information
Summary by CodeRabbit