Skip to content

Commit ee7588a

Browse files
committed
[5541172] update based on review comments
Signed-off-by: unknown <[email protected]>
1 parent 091116d commit ee7588a

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -496,10 +496,13 @@ def validate_inputs(hf_model, ep_path_pairs):
496496
Returns:
497497
bool: True if all inputs are valid, False otherwise.
498498
"""
499-
# Check HF model path (only if provided)
500-
if hf_model and not os.path.exists(hf_model):
501-
print(f"[ERROR] Hugging Face model path does not exist: {hf_model}")
502-
return False
499+
# Check HF model path (only if provided and it looks like a local path)
500+
# If it doesn't exist locally, assume it's a HF model name to be downloaded
501+
if hf_model and os.path.exists(hf_model):
502+
# Verify it's a valid directory
503+
if not os.path.isdir(hf_model):
504+
print(f"[ERROR] Hugging Face model path is not a directory: {hf_model}")
505+
return False
503506

504507
# Check execution providers and paths
505508
for ep, path in ep_path_pairs:
@@ -525,6 +528,10 @@ def main():
525528
python compute_kl_divergence.py --hf_model "F:\\shared\\Llama-3.1-8B-Instruct"
526529
--ep cuda --path "G:\\models\\cuda_model" --output "hf_vs_cuda.json"
527530
531+
# Compare HF vs CUDA model (download from Hugging Face)
532+
python compute_kl_divergence.py --hf_model "meta-llama/Llama-3.1-8B-Instruct"
533+
--ep cuda --path "G:\\models\\cuda_model" --output "hf_vs_cuda.json"
534+
528535
# Compare HF vs CUDA vs DirectML models
529536
python compute_kl_divergence.py --hf_model "F:\\shared\\Llama-3.1-8B-Instruct"
530537
--ep cuda --path "G:\\models\\cuda_model"

examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,7 @@ def main():
182182
global DEBUG
183183
DEBUG = args.debug
184184

185-
# Validate model directory exists
186-
if not os.path.exists(args.model_path):
187-
print(f"Error: Model directory does not exist: {args.model_path}")
188-
return 1
185+
189186

190187
try:
191188
# Extract logits

0 commit comments

Comments
 (0)