@@ -149,6 +149,7 @@ def config_defaults() -> dict:
149149 "smoothq" : False ,
150150 "smoothq_scale_layers" : [],
151151 "smoothq_act_scale_path" : None ,
152+ "smooth_attn" : False ,
152153 # Other vars
153154 "which2patch_contextmanager" : None ,
154155 "force_stop_if_qbmm_auto_check_failed" : False ,
@@ -940,11 +941,16 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
940941 "pactsym+" ,
941942 "max" ,
942943 "minmax" ,
944+ "maxbmm" ,
943945 "maxsym" ,
944946 "pertokenmax" ,
945947 "lsq+" ,
946948 "fix" ,
947949 "brecq" ,
950+ ]
951+ shared_modes = [
952+ "max_perToken" ,
953+ "max_perCh" ,
948954 # fp8_e4m3
949955 "fp8_e4m3_sat" ,
950956 "fp8_e4m3_scale" ,
@@ -981,33 +987,34 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
981987 "brecq" ,
982988 "adaround" ,
983989 "pertokenmax" ,
984- # fp8_e4m3
985- "fp8_e4m3_sat" ,
986- "fp8_e4m3_scale" ,
987- "fp8_e4m3_sat_perCh" ,
988- "fp8_e4m3_scale_perCh" ,
989- "fp8_e4m3_sat_perToken" ,
990- "fp8_e4m3_scale_perToken" ,
991- # fp8_e5m2
992- "fp8_e5m2_sat" ,
993- "fp8_e5m2_scale" ,
994- "fp8_e5m2_sat_perCh" ,
995- "fp8_e5m2_scale_perCh" ,
996- "fp8_e5m2_sat_perToken" ,
997- "fp8_e5m2_scale_perToken" ,
990+ # # fp8_e4m3
991+ # "fp8_e4m3_sat",
992+ # "fp8_e4m3_scale",
993+ # "fp8_e4m3_sat_perCh",
994+ # "fp8_e4m3_scale_perCh",
995+ # "fp8_e4m3_sat_perToken",
996+ # "fp8_e4m3_scale_perToken",
997+ # # fp8_e5m2
998+ # "fp8_e5m2_sat",
999+ # "fp8_e5m2_scale",
1000+ # "fp8_e5m2_sat_perCh",
1001+ # "fp8_e5m2_scale_perCh",
1002+ # "fp8_e5m2_sat_perToken",
1003+ # "fp8_e5m2_scale_perToken",
9981004 ]
9991005 bmm_mode_settings = [
10001006 "pact" ,
10011007 "pactsym" ,
10021008 "pactsym+" ,
10031009 "maxsym" ,
1010+ "maxbmm" ,
10041011 "max" ,
10051012 "minmax" ,
10061013 "pertokenmax" ,
1007- "fp8_e4m3_sat" ,
1008- "fp8_e4m3_scale_perToken" ,
1009- "fp8_e5m2_sat" ,
1010- "fp8_e5m2_scale_perToken" ,
1014+ # "fp8_e4m3_sat",
1015+ # "fp8_e4m3_scale_perToken",
1016+ # "fp8_e5m2_sat",
1017+ # "fp8_e5m2_scale_perToken",
10111018 ]
10121019
10131020 # Get strings in config for qa_modes, qw_modes, bmm_modes
@@ -1043,15 +1050,15 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
     # Check each for correct ranges
     for qa_mode_str in qa_modes_str:
         qa_mode = config.get(qa_mode_str, "pact+")
-        if not qa_mode in (qa_mode_settings + mx_spec_config_modes):
+        if not qa_mode in (qa_mode_settings + mx_spec_config_modes + shared_modes):
             raise ValueError(
                 f"{qa_mode_str} = {qa_mode} is not set to one of the following: "
                 f"{qa_mode_settings + mx_spec_config_modes}"
             )
 
     for qw_mode_str in qw_modes_str:
         qw_mode = config.get(qw_mode_str, "sawb+")
-        if not qw_mode in (qw_mode_settings + mx_spec_config_modes):
+        if not qw_mode in (qw_mode_settings + mx_spec_config_modes + shared_modes):
             raise ValueError(
                 f"{qw_mode_str} = {qw_mode} is not set to one of the following: "
                 f"{qw_mode_settings + mx_spec_config_modes}"
@@ -1063,7 +1070,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         bmm_mode_consistency += bmm_mode.startswith("mx_")
         # mx_specs doesn't have 4 individual bmmX_qmY_modes, it re-uses w and a fmt instead.
         # We will keep them in qcfg (with "mx_" prefix NOT removed).
-        if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes):
+        if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes + shared_modes):
             raise ValueError(
                 f"{bmm_mode_str} = {bmm_mode} is not set to one of the following: "
                 f"{bmm_mode_settings + mx_spec_config_modes}"
@@ -1101,6 +1108,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
11011108 "qskip_large_mag_layers" ,
11021109 "recompute_narrow_weights" ,
11031110 "smoothq" ,
1111+ "smooth_attn" ,
11041112 ]
11051113 for boolean_var_str in boolean_vars_str :
11061114 boolean_var = config .get (