@@ -60,7 +60,7 @@ def test_large_outlier_bert(
6060 # Third Party
6161 import torch
6262
63- # Break every tensor channel with a large magnitude outlier
63+ # Break every tensor channel with a large magnitude outlier - should work for per tensor too
6464 for k , v in model_tiny_bert .state_dict ().items ():
6565 if k .endswith (".weight" ) and any (n in k for n in bert_linear_names ):
6666 v [:, 0 ] = 1.21
@@ -69,11 +69,15 @@ def test_large_outlier_bert(
6969 qcfg_bert ["recompute_narrow_weights" ] = True
7070 qmodel_prep (model_tiny_bert , input_tiny , qcfg_bert , use_dynamo = True )
7171
72+ # Std dev reduction dim: last dim (-1) for perCh, all dims (None) for perTensor
73+ stddev_dim = - 1 if "perCh" in qcfg_bert ["qw_mode" ] else None
74+
7275 # With a large outlier, Qmax should break the quantization and yield a skinny distribution
7376 layer2stdev : dict [str , torch .Tensor ] = {}
7477 for k , v in model_tiny_bert .state_dict ().items ():
7578 if k .endswith (".weight" ) and any (n in k for n in bert_linear_names ):
76- layer2stdev [k ] = v .to (torch .float32 ).std (dim = - 1 )
79+ # Collect perCh or perTensor std dev
80+ layer2stdev [k ] = v .to (torch .float32 ).std (dim = stddev_dim )
7781
7882 save_for_aiu (model_tiny_bert , qcfg = qcfg_bert , verbose = True )
7983 state_dict = load_state_dict ()
@@ -82,8 +86,9 @@ def test_large_outlier_bert(
8286 for k , v in state_dict .items ():
8387 if k .endswith (".weight" ) and any (n in k for n in bert_linear_names ):
8488 perCh_stdev_model = layer2stdev .get (k )
85- perCh_stdev_loaded = v .to (torch .float32 ).std (dim = - 1 )
89+ perCh_stdev_loaded = v .to (torch .float32 ).std (dim = stddev_dim )
8690
91+ # SAWB stddev should be at least as large as the Qmax stddev w/ outlier
8792 assert torch .all (perCh_stdev_loaded >= perCh_stdev_model )
8893
8994
0 commit comments