 
 import pytest
 import torch
-import torch.nn as nn
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -374,8 +373,10 @@ def test_fp8_real_quantize():
 
 def _test_kv_cache_quant_helper(config, rank, size):
     """Helper function for testing KV cache quantization with TEDotProductAttention."""
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
     # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
     model = get_mcore_gpt_model(
@@ -386,43 +387,45 @@ def _test_kv_cache_quant_helper(config, rank, size):
         vocab_size=32,
         transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
     ).cuda()
-
+
     # Create dummy input for calibration
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
-
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Test KV cache quantization with the given config
     quantized_model = mtq.quantize(model, config, forward_fn)
-
+
     # Find TEDotProductAttention modules and verify they have KV cache quantizers
     te_attention_found = False
     for name, module in quantized_model.named_modules():
         # Check if this is a quantized TEDotProductAttention
-        if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'):
+        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
             te_attention_found = True
             # Verify all expected quantizers exist
-            assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}"
-
+            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"
+
             # Verify K and V quantizers are enabled (main purpose of KV cache configs)
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"
-
+
     # Quick smoke test that forward still works
     output = forward_fn(quantized_model)
     assert output is not None, "Forward pass failed"
-
+
 
 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
     """Helper for testing KV cache quantization with sharded state dict save/load."""
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}
-
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
     model_ref = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
@@ -432,7 +435,7 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
     ).cuda()
-
+
     model_test = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
         num_layers=2,
@@ -441,29 +444,31 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",
     ).cuda()
-
-    prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda()
-
+
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Quantize the reference model
     model_ref = mtq.quantize(model_ref, config, forward_fn)
-
+
     # CRITICAL: model_test must also be quantized with the same config
     # Otherwise it won't have the KV cache quantizer keys when loading state dict
     model_test = mtq.quantize(model_test, config, forward_fn)
-
+
     # Verify KV cache quantizers were created
     kv_quantizers_found = False
     for name, module in model_ref.named_modules():
-        if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'):
+        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             kv_quantizers_found = True
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
-
+
     # Test sharded state dict save/load
     sharded_state_dict_test_helper(
         tmp_path,
@@ -473,32 +478,38 @@ def forward_fn(model):
         meta_device=False,
         version=None,
     )
-
+
     # Verify KV cache quantizers are restored correctly in model_test
     for (name_ref, module_ref), (name_test, module_test) in zip(
         model_ref.named_modules(), model_test.named_modules()
     ):
-        if hasattr(module_ref, 'k_bmm_quantizer'):
-            assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}"
-            assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}"
-
+        if hasattr(module_ref, "k_bmm_quantizer"):
+            assert hasattr(module_test, "k_bmm_quantizer"), (
+                f"K quantizer missing after restore in {name_test}"
+            )
+            assert hasattr(module_test, "v_bmm_quantizer"), (
+                f"V quantizer missing after restore in {name_test}"
+            )
+
             # Check that quantizer states match
-            if hasattr(module_ref.k_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}"
+            if hasattr(module_ref.k_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
+                    f"K quantizer _amax missing in {name_test}"
+                )
                 if module_ref.k_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.k_bmm_quantizer._amax,
-                        module_test.k_bmm_quantizer._amax
+                        module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
                     ), f"K quantizer _amax mismatch in {name_test}"
-
-            if hasattr(module_ref.v_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}"
+
+            if hasattr(module_ref.v_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
+                    f"V quantizer _amax missing in {name_test}"
+                )
                 if module_ref.v_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.v_bmm_quantizer._amax,
-                        module_test.v_bmm_quantizer._amax
+                        module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
                     ), f"V quantizer _amax mismatch in {name_test}"
-
+
 
 @pytest.mark.parametrize(
     "config",
@@ -509,16 +520,14 @@ def forward_fn(model):
 )
 def test_kv_cache_quant(config):
     """Verify KV cache quantization works correctly with TEDotProductAttention.
-
-    This test ensures TEDotProductAttention is properly registered and gets the
+
+    This test ensures TEDotProductAttention is properly registered and gets the
     expected q/k/v_bmm_quantizers when using KV cache configs.
-
+
     Note: This test requires Transformer Engine to be installed since TEDotProductAttention
     is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
     """
-    spawn_multiprocess_job(
-        size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl"
-    )
+    spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl")
 
 
 @pytest.mark.parametrize(
524533@pytest .mark .parametrize (
@@ -530,7 +539,7 @@ def test_kv_cache_quant(config):
530539)
531540def test_kv_cache_sharded_state_dict (tmp_path , config ):
532541 """Test KV cache quantization with sharded state dict save/load.
533-
542+
534543 This test verifies the complete workflow of saving and loading KV cache quantized
535544 models with distributed checkpointing, ensuring quantizer states are properly
536545 preserved across the save/load cycle.
@@ -539,5 +548,5 @@ def test_kv_cache_sharded_state_dict(tmp_path, config):
539548 spawn_multiprocess_job (
540549 size = size ,
541550 job = partial (_test_kv_cache_sharded_state_dict_helper , tmp_path , config ),
542- backend = "nccl"
551+ backend = "nccl" ,
543552 )
0 commit comments