@@ -948,3 +948,91 @@ def test_compare_large_sigma_rel_with_traditional_ema():
948
948
"ema_values" : ema_values ,
949
949
"best_beta" : best_beta ,
950
950
}
951
+
952
+
953
def test_compare_very_small_sigma_rel_with_precise_beta():
    """
    Compare PostHocEMA at sigma_rel=0.001 with traditional EMAs for beta near 0.4.

    This test:
    1. Ramps a single-parameter model linearly from 0 to 1 over 1000 steps,
       updating a PostHocEMA (sigma_rel=0.001) after every step.
    2. Tracks a traditional EMA of the same ramp for each beta in 0.35 .. 0.45.
    3. Prints every traditional EMA next to the PostHocEMA value and reports
       the beta whose final value is closest.

    Returns:
        dict with keys "posthoc_value" (final PostHocEMA parameter value),
        "ema_values" (mapping beta -> final traditional EMA value) and
        "best_beta" (the beta with the smallest absolute difference).
    """
    # Common parameters
    num_steps = 1000

    # Create models
    model_posthoc = SingleParamModel(initial_value=0.0)

    # Import our implementation
    from posthoc_ema import PostHocEMA

    # Create our PostHocEMA instance
    our_ema = PostHocEMA.from_model(
        model_posthoc,
        "./test-our-ema",
        sigma_rels=(0.001,),
        update_every=1,  # Update every step for more precision
        checkpoint_every=50,
        update_after_step=0,  # Start immediately
    )

    # Traditional EMA beta values to test, focusing around 0.4 (kept ascending
    # so the comparison table below prints in order).
    beta_values = [0.35, 0.36, 0.37, 0.38, 0.39, 0.40, 0.41, 0.42, 0.43, 0.44, 0.45]

    # Initialize EMA values for each beta
    ema_values = {beta: 0.0 for beta in beta_values}

    # Gradually update the model from 0 to 1
    for step in range(num_steps):
        # Linear interpolation from 0 to 1
        target_value = step / (num_steps - 1)

        with torch.no_grad():
            model_posthoc.param.copy_(torch.tensor([target_value], dtype=torch.float32))

        # Update traditional EMA values
        for beta in beta_values:
            ema_values[beta] = beta * ema_values[beta] + (1 - beta) * target_value

        our_ema.update_(model_posthoc)

    # Get PostHocEMA value
    with our_ema.state_dict(sigma_rel=0.001) as state_dict:
        posthoc_value = state_dict["param"].item()

    def percent_of_posthoc(diff):
        # Relative error in percent. A PostHocEMA value of exactly zero would
        # raise ZeroDivisionError while printing, so report NaN instead.
        if posthoc_value == 0:
            return float("nan")
        return diff / posthoc_value * 100

    print(f"\n PostHocEMA value for sigma_rel=0.001: {posthoc_value:.6f} ")

    # Absolute differences are computed once and reused for both the table
    # and the closest-match search.
    diffs = {beta: abs(ema_values[beta] - posthoc_value) for beta in beta_values}

    # Print comparison with traditional EMA values
    print("\n Comparison with beta values around 0.4:")
    for beta in beta_values:
        print(
            f" Beta={beta:.2f} : {ema_values[beta]:.6f} (diff: {diffs[beta]:.6f} , {percent_of_posthoc(diffs[beta]):.4f} %)"
        )

    # Find the closest match; min() keeps the first beta on ties, matching a
    # strict "<" scan, and never leaves best_beta unset.
    best_beta = min(beta_values, key=diffs.__getitem__)
    best_diff = diffs[best_beta]

    print(
        f"\n Closest match: beta={best_beta:.2f} with difference {best_diff:.6f} ({percent_of_posthoc(best_diff):.4f} %)"
    )

    # Determine decay speed category
    decay_speed = "Very fast decay"

    print("\n README mapping entry:")
    print(f"beta = {best_beta:.2f} # {decay_speed} -> sigma_rel ≈ 0.001")

    return {
        "posthoc_value": posthoc_value,
        "ema_values": ema_values,
        "best_beta": best_beta,
    }
0 commit comments