@@ -17,7 +17,6 @@
 
 import pytest
 import torch
-import torch.nn as nn
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -368,8 +367,10 @@ def test_fp8_real_quantize():
 
 def _test_kv_cache_quant_helper(config, rank, size):
     """Helper function for testing KV cache quantization with TEDotProductAttention."""
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
     # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
     model = get_mcore_gpt_model(
@@ -380,43 +381,45 @@ def _test_kv_cache_quant_helper(config, rank, size):
         vocab_size=32,
         transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
     ).cuda()
-
+
     # Create dummy input for calibration
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
-
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Test KV cache quantization with the given config
     quantized_model = mtq.quantize(model, config, forward_fn)
-
+
     # Find TEDotProductAttention modules and verify they have KV cache quantizers
     te_attention_found = False
     for name, module in quantized_model.named_modules():
         # Check if this is a quantized TEDotProductAttention
-        if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'):
+        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
             te_attention_found = True
             # Verify all expected quantizers exist
-            assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}"
-
+            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"
+
             # Verify K and V quantizers are enabled (main purpose of KV cache configs)
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"
-
+
     # Quick smoke test that forward still works
     output = forward_fn(quantized_model)
     assert output is not None, "Forward pass failed"
-
+
 
 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
     """Helper for testing KV cache quantization with sharded state dict save/load."""
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}
-
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
     model_ref = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
@@ -426,7 +429,7 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
     ).cuda()
-
+
     model_test = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
         num_layers=2,
@@ -435,29 +438,31 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",
     ).cuda()
-
-    prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda()
-
+
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Quantize the reference model
     model_ref = mtq.quantize(model_ref, config, forward_fn)
-
+
     # CRITICAL: model_test must also be quantized with the same config
     # Otherwise it won't have the KV cache quantizer keys when loading state dict
     model_test = mtq.quantize(model_test, config, forward_fn)
-
+
     # Verify KV cache quantizers were created
     kv_quantizers_found = False
     for name, module in model_ref.named_modules():
-        if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'):
+        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             kv_quantizers_found = True
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
-
+
     # Test sharded state dict save/load
     sharded_state_dict_test_helper(
         tmp_path,
@@ -467,32 +472,38 @@ def forward_fn(model):
         meta_device=False,
         version=None,
     )
-
+
     # Verify KV cache quantizers are restored correctly in model_test
     for (name_ref, module_ref), (name_test, module_test) in zip(
         model_ref.named_modules(), model_test.named_modules()
     ):
-        if hasattr(module_ref, 'k_bmm_quantizer'):
-            assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}"
-            assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}"
-
+        if hasattr(module_ref, "k_bmm_quantizer"):
+            assert hasattr(module_test, "k_bmm_quantizer"), (
+                f"K quantizer missing after restore in {name_test}"
+            )
+            assert hasattr(module_test, "v_bmm_quantizer"), (
+                f"V quantizer missing after restore in {name_test}"
+            )
+
             # Check that quantizer states match
-            if hasattr(module_ref.k_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}"
+            if hasattr(module_ref.k_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
+                    f"K quantizer _amax missing in {name_test}"
+                )
                 if module_ref.k_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.k_bmm_quantizer._amax,
-                        module_test.k_bmm_quantizer._amax
+                        module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
                     ), f"K quantizer _amax mismatch in {name_test}"
-
-            if hasattr(module_ref.v_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}"
+
+            if hasattr(module_ref.v_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
+                    f"V quantizer _amax missing in {name_test}"
+                )
                 if module_ref.v_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.v_bmm_quantizer._amax,
-                        module_test.v_bmm_quantizer._amax
+                        module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
                     ), f"V quantizer _amax mismatch in {name_test}"
-
+
 
 @pytest.mark.parametrize(
     "config",
@@ -503,16 +514,14 @@ def forward_fn(model):
 )
 def test_kv_cache_quant(config):
     """Verify KV cache quantization works correctly with TEDotProductAttention.
-
-    This test ensures TEDotProductAttention is properly registered and gets the
+
+    This test ensures TEDotProductAttention is properly registered and gets the
     expected q/k/v_bmm_quantizers when using KV cache configs.
-
+
     Note: This test requires Transformer Engine to be installed since TEDotProductAttention
     is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
     """
-    spawn_multiprocess_job(
-        size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl"
-    )
+    spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl")
 
 
 @pytest.mark.parametrize(
@@ -524,7 +533,7 @@ def test_kv_cache_quant(config):
524533)
525534def test_kv_cache_sharded_state_dict (tmp_path , config ):
526535 """Test KV cache quantization with sharded state dict save/load.
527-
536+
528537 This test verifies the complete workflow of saving and loading KV cache quantized
529538 models with distributed checkpointing, ensuring quantizer states are properly
530539 preserved across the save/load cycle.
@@ -533,5 +542,5 @@ def test_kv_cache_sharded_state_dict(tmp_path, config):
     spawn_multiprocess_job(
         size=size,
         job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config),
-        backend="nccl"
+        backend="nccl",
     )
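
Side note (not part of this diff): both helpers repeat the same scan for `k_bmm_quantizer` / `v_bmm_quantizer` attributes. A minimal sketch of how that check could be factored into a shared utility, assuming only the attribute names already used in the tests above (`k_bmm_quantizer`, `v_bmm_quantizer`, `is_enabled`); the function name and signature are hypothetical:

```python
import torch.nn as nn


def find_enabled_kv_quantizers(model: nn.Module) -> list[str]:
    """Return names of submodules that expose enabled K/V BMM quantizers."""
    matches = []
    for name, module in model.named_modules():
        # Same attribute checks as in the helpers above.
        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
            if module.k_bmm_quantizer.is_enabled and module.v_bmm_quantizer.is_enabled:
                matches.append(name)
    return matches
```

Both helpers could then assert that `find_enabled_kv_quantizers(quantized_model)` is non-empty instead of tracking a boolean flag.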