30 | 30 | get_processor, |
31 | 31 | get_tokenizer, |
32 | 32 | is_enc_dec, |
| 33 | + is_nemotron_vl_model, |
33 | 34 | ) |
34 | 35 | from transformers import ( |
35 | 36 | AutoConfig, |
@@ -284,8 +285,8 @@ def main(args): |
284 | 285 |
285 | 286 | full_model = model |
286 | 287 |
287 | | - # Detect if this is a Nemotron VL model using model-based detection |
288 | | - is_nemotron_vl = is_multimodal_model(full_model) and "nemotron" in args.pyt_ckpt_path.lower() |
| 288 | + # Detect if this is a Nemotron VL model using architecture-based detection |
| 289 | + is_nemotron_vl = is_nemotron_vl_model(full_model) |
289 | 290 |
290 | 291 | if model_type == "mllama": |
291 | 292 | processor = get_processor( |
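The new `is_nemotron_vl_model` helper replaces the checkpoint-path heuristic removed above. Its implementation is not part of this diff; the following is a minimal sketch of what architecture-based detection could look like, assuming the Hugging Face config exposes an `architectures` list (only the helper's name comes from the diff, the body is illustrative):

```python
# Illustrative sketch only -- the real helper lives in the example's utils module.
def is_nemotron_vl_model(model) -> bool:
    """Detect Nemotron VL models from registered architecture names, not the checkpoint path."""
    architectures = getattr(model.config, "architectures", None) or []
    return any("nemotron" in arch.lower() and "vl" in arch.lower() for arch in architectures)
```

Keying detection off the model config makes it robust to arbitrary local checkpoint directory names, which was the weakness of the removed `"nemotron" in args.pyt_ckpt_path.lower()` check.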
@@ -470,59 +471,36 @@ def main(args): |
470 | 471 | "input_features" if model_type == "whisper" else "input_ids" |
471 | 472 | ][0:1] |
472 | 473 |
473 | | - # For Nemotron VL models, try text-only generation first, then VL generation as additional test |
474 | | - if is_nemotron_vl: |
| 474 | + # Generate preview before quantization |
| 475 | + if is_nemotron_vl and tokenizer is not None: |
475 | 476 | print("Running text-only preview generation for Nemotron VL model...") |
476 | | - try: |
477 | | - # Try text-only generation using helper function that supports both v1 and v2 |
478 | | - if tokenizer is None: |
479 | | - raise ValueError("Tokenizer is required for Nemotron VL text generation") |
480 | | - |
481 | | - question = tokenizer.decode(input_ids[0], skip_special_tokens=True) |
482 | | - generation_config = { |
483 | | - "max_new_tokens": 100, |
484 | | - "do_sample": False, |
485 | | - "eos_token_id": tokenizer.eos_token_id, |
486 | | - } |
487 | | - |
488 | | - # Use helper function that supports both v1 and v2 models |
489 | | - text_response = run_text_only_generation( |
490 | | - full_model, tokenizer, question, generation_config, args.pyt_ckpt_path |
491 | | - ) |
492 | | - |
493 | | - if text_response is not None: |
494 | | - generated_ids_before_ptq = text_response # Store text response |
495 | | - print(f"✅ Text-only generation successful: {text_response[:100]}...") |
496 | | - else: |
497 | | - raise Exception("Text-only generation returned None") |
| 477 | + question = tokenizer.decode(input_ids[0], skip_special_tokens=True) |
| 478 | + generation_config = { |
| 479 | + "max_new_tokens": 100, |
| 480 | + "do_sample": False, |
| 481 | + "eos_token_id": tokenizer.eos_token_id, |
| 482 | + } |
| 483 | + |
| 484 | + # Try text-only generation first, fall back to standard generate |
| 485 | + text_response = run_text_only_generation( |
| 486 | + full_model, tokenizer, question, generation_config, args.pyt_ckpt_path |
| 487 | + ) |
498 | 488 |
499 | | - except Exception as e: |
500 | | - print(f"Text-only generation failed: {e}") |
501 | | - print("Falling back to standard generate() method...") |
502 | | - try: |
503 | | - generated_ids_before_ptq = full_model.generate( |
504 | | - input_ids, max_new_tokens=100 |
505 | | - ) |
506 | | - except Exception as e2: |
507 | | - print(f"Standard generation also failed: {e2}") |
508 | | - generated_ids_before_ptq = None |
| 489 | + if text_response is not None: |
| 490 | + generated_ids_before_ptq = text_response |
| 491 | + print(f"✅ Text-only generation successful: {text_response[:100]}...") |
| 492 | + else: |
| 493 | + print("Text-only generation failed, falling back to standard generate...") |
| 494 | + generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) |
509 | 495 |
510 | 496 | # Run additional VL test with images |
511 | 497 | print("Running additional VL test with images...") |
512 | 498 | run_vl_preview_generation( |
513 | 499 | full_model, tokenizer, args.pyt_ckpt_path, "before quantization (VL test)" |
514 | 500 | ) |
515 | | - |
516 | 501 | else: |
517 | | - try: |
518 | | - generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) |
519 | | - except Exception as e: |
520 | | - print( |
521 | | - "Error during model generation. Please check if your transformers version is " |
522 | | - "compatible with the model." |
523 | | - ) |
524 | | - print(f"Error details: {e}") |
525 | | - raise |
| 502 | + # Standard generation for all other models and when no tokenizer is available
| 503 | + generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) |
526 | 504 | if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": |
527 | 505 | print("Applying nvfp4 quantization (MoE only) for gpt-oss") |
528 | 506 |
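The simplified flow above assumes `run_text_only_generation` returns the decoded text on success and `None` on failure, which lets the caller fall back to plain `generate()` without the nested `try`/`except` blocks that were removed. A minimal sketch of that contract, assuming an InternVL-style `model.chat` API for the chat-capable variants (the signature is taken from the call site; the branching inside is illustrative):

```python
# Illustrative sketch of the contract the caller relies on: decoded text on
# success, None on failure so the caller can fall back to model.generate().
def run_text_only_generation(model, tokenizer, question, generation_config, ckpt_path):
    # ckpt_path: the real helper may use it for v1/v2 dispatch; unused in this sketch.
    try:
        if hasattr(model, "chat"):
            # InternVL-style chat API (assumed); pixel_values=None keeps it text-only.
            return model.chat(tokenizer, None, question, generation_config)
        inputs = tokenizer(question, return_tensors="pt").to(model.device)
        output_ids = model.generate(**inputs, **generation_config)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Text-only generation failed: {e}")
        return None  # caller falls back to full_model.generate(...)
```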