Commit 7a6e1b0

smooth attention
Signed-off-by: Iqbal Saraf <[email protected]>
1 parent 31c4e2f commit 7a6e1b0

File tree: 3 files changed, +44 -22 lines

fms_mo/modules/bmm.py

Lines changed: 13 additions & 1 deletion
@@ -82,7 +82,13 @@ def __init__(
         self.m2_bounded = m2_bounded
         self.qm1_mode = qm1_mode
         self.qm2_mode = qm2_mode
-
+        self.smooth_attn = qcfg.get("smooth_attn", False)
+        self.smooth_attn_alpha = qcfg.get("smooth_attn_alpha", 0.5)
+        if self.smooth_attn_alpha < 0 or self.smooth_attn_alpha > 1:
+            raise ValueError(
+                "smooth_attn_alpha must be in range [0,1] "
+                f"(given: {self.smooth_attn_alpha})"
+            )
         self.m1_clip_init_val = kwargs.get(
             "m1_clip_init_val", qcfg.get("m1_clip_init_val", 1.0)
         )
@@ -191,6 +197,12 @@ def forward(self, m1, m2):
         Returns:
             torch.Tensor: Output tensor after quantized bmm.
         """
+        if self.smooth_attn:
+            attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5)
+            attn_scales = attn_scales.pow(self.smooth_attn_alpha)
+            m1 *= attn_scales
+            m2 /= attn_scales.reshape(1, 1, m2.shape[2], 1)
+
         # pylint: disable = access-member-before-definition
         if self.calib_counter:
             with torch.no_grad():
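
Note on the forward-path change above: the smoothing is a mathematical no-op on the bmm output. m1 is multiplied by per-channel scales along its last (contraction) dimension while m2 is divided by the same scales along its matching dimension, so m1 @ m2 is unchanged; only the dynamic ranges seen by the two quantizers shift, with smooth_attn_alpha controlling how much of m2's range migrates into m1. A minimal standalone sketch of that invariance (the 4-D shapes below are illustrative assumptions, not taken from fms_mo):

import torch

# Illustrative shapes: m1 ~ attention probs [B, H, T, T], m2 ~ value states [B, H, T, d]
B, H, T, d = 2, 4, 8, 16
m1 = torch.rand(B, H, T, T)
m2 = torch.randn(B, H, T, d)

ref = m1 @ m2  # un-smoothed result

alpha = 0.5  # i.e. qcfg["smooth_attn_alpha"]
# Per-channel scales over m2's contraction dim (dim 2), as in the commit
attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5).pow(alpha)

m1_s = m1 * attn_scales                                # broadcasts over m1's last dim
m2_s = m2 / attn_scales.reshape(1, 1, m2.shape[2], 1)  # divides along dim 2

out = m1_s @ m2_s
print(torch.allclose(ref, out, atol=1e-5))  # True: the product is preserved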

fms_mo/training_args.py

Lines changed: 2 additions & 0 deletions
@@ -164,6 +164,8 @@ class FMSMOArguments(TypeChecker):
     bmm2_qm1_mode: str = field(default="pact", metadata={"help": ("bmm2.m1 quanitzer")})
     bmm2_qm2_mode: str = field(default="pact", metadata={"help": ("bmm2.m1 quanitzer")})
     smoothq_alpha: float = field(default=0.65, metadata={"help": "smooth quant alpha"})
+    smooth_attn_alpha: float = field(default=0.5, metadata={"help": "smooth attention alpha"})
+    smooth_attn: bool = field(default=False, metadata={"help": "enable smooth attention"})
     qmodel_calibration: int = field(
         default=0,
         metadata={"help": "Num of batches for Qmodel calibration, using model copy."},
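
The two new arguments mirror the qcfg keys read in __init__ above ("smooth_attn", "smooth_attn_alpha"), and smooth_attn_alpha is bounded to [0, 1] by the new check. A minimal sketch of that bound, using a hypothetical stand-alone helper rather than the real constructor (which takes several other arguments not shown here):

# Hypothetical helper reproducing only the range check added in this commit.
def validate_smooth_attn_alpha(qcfg: dict) -> float:
    smooth_attn_alpha = qcfg.get("smooth_attn_alpha", 0.5)
    if smooth_attn_alpha < 0 or smooth_attn_alpha > 1:
        raise ValueError(
            "smooth_attn_alpha must be in range [0,1] "
            f"(given: {smooth_attn_alpha})"
        )
    return smooth_attn_alpha

validate_smooth_attn_alpha({"smooth_attn": True, "smooth_attn_alpha": 0.5})  # ok
validate_smooth_attn_alpha({"smooth_attn": True, "smooth_attn_alpha": 1.5})  # raises ValueError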

fms_mo/utils/qconfig_utils.py

Lines changed: 29 additions & 21 deletions
@@ -149,6 +149,7 @@ def config_defaults() -> dict:
         "smoothq": False,
         "smoothq_scale_layers": [],
         "smoothq_act_scale_path": None,
+        "smooth_attn": False,
         # Other vars
         "which2patch_contextmanager": None,
         "force_stop_if_qbmm_auto_check_failed": False,
@@ -940,11 +941,16 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "pactsym+",
         "max",
         "minmax",
+        "maxbmm",
         "maxsym",
         "pertokenmax",
         "lsq+",
         "fix",
         "brecq",
+    ]
+    shared_modes = [
+        "max_perToken",
+        "max_perCh",
         # fp8_e4m3
         "fp8_e4m3_sat",
         "fp8_e4m3_scale",
@@ -981,33 +987,34 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "brecq",
         "adaround",
         "pertokenmax",
-        # fp8_e4m3
-        "fp8_e4m3_sat",
-        "fp8_e4m3_scale",
-        "fp8_e4m3_sat_perCh",
-        "fp8_e4m3_scale_perCh",
-        "fp8_e4m3_sat_perToken",
-        "fp8_e4m3_scale_perToken",
-        # fp8_e5m2
-        "fp8_e5m2_sat",
-        "fp8_e5m2_scale",
-        "fp8_e5m2_sat_perCh",
-        "fp8_e5m2_scale_perCh",
-        "fp8_e5m2_sat_perToken",
-        "fp8_e5m2_scale_perToken",
+        # # fp8_e4m3
+        # "fp8_e4m3_sat",
+        # "fp8_e4m3_scale",
+        # "fp8_e4m3_sat_perCh",
+        # "fp8_e4m3_scale_perCh",
+        # "fp8_e4m3_sat_perToken",
+        # "fp8_e4m3_scale_perToken",
+        # # fp8_e5m2
+        # "fp8_e5m2_sat",
+        # "fp8_e5m2_scale",
+        # "fp8_e5m2_sat_perCh",
+        # "fp8_e5m2_scale_perCh",
+        # "fp8_e5m2_sat_perToken",
+        # "fp8_e5m2_scale_perToken",
     ]
     bmm_mode_settings = [
         "pact",
         "pactsym",
         "pactsym+",
         "maxsym",
+        "maxbmm",
         "max",
         "minmax",
         "pertokenmax",
-        "fp8_e4m3_sat",
-        "fp8_e4m3_scale_perToken",
-        "fp8_e5m2_sat",
-        "fp8_e5m2_scale_perToken",
+        # "fp8_e4m3_sat",
+        # "fp8_e4m3_scale_perToken",
+        # "fp8_e5m2_sat",
+        # "fp8_e5m2_scale_perToken",
     ]

     # Get strings in config for qa_modes, qw_modes, bmm_modes
@@ -1043,15 +1050,15 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
     # Check each for correct ranges
     for qa_mode_str in qa_modes_str:
         qa_mode = config.get(qa_mode_str, "pact+")
-        if not qa_mode in (qa_mode_settings + mx_spec_config_modes):
+        if not qa_mode in (qa_mode_settings + mx_spec_config_modes + shared_modes):
             raise ValueError(
                 f"{qa_mode_str} = {qa_mode} is not set to one of the following: "
                 f"{qa_mode_settings + mx_spec_config_modes}"
             )

     for qw_mode_str in qw_modes_str:
         qw_mode = config.get(qw_mode_str, "sawb+")
-        if not qw_mode in (qw_mode_settings + mx_spec_config_modes):
+        if not qw_mode in (qw_mode_settings + mx_spec_config_modes + shared_modes):
             raise ValueError(
                 f"{qw_mode_str} = {qw_mode} is not set to one of the following: "
                 f"{qw_mode_settings + mx_spec_config_modes}"
@@ -1063,7 +1070,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         bmm_mode_consistency += bmm_mode.startswith("mx_")
         # mx_specs doesn't have 4 individual bmmX_qmY_modes, it re-uses w and a fmt instead.
         # We will keep them in qcfg (with "mx_" prefix NOT removed).
-        if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes):
+        if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes + shared_modes):
             raise ValueError(
                 f"{bmm_mode_str} = {bmm_mode} is not set to one of the following: "
                 f"{bmm_mode_settings + mx_spec_config_modes}"
@@ -1101,6 +1108,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "qskip_large_mag_layers",
         "recompute_narrow_weights",
         "smoothq",
+        "smooth_attn",
     ]
     for boolean_var_str in boolean_vars_str:
         boolean_var = config.get(
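
The new shared_modes list ("max_perToken", "max_perCh") is appended to all three membership checks, so these two modes are accepted for activation, weight, and bmm quantizers without repeating them in every per-quantizer list; "maxbmm" is likewise added to the activation and bmm mode lists. A minimal standalone sketch of that check pattern (the lists are truncated for illustration and the helper is hypothetical, not part of fms_mo):

# Truncated, illustrative stand-ins for the lists built in check_config
qa_mode_settings = ["pact", "pact+", "max", "minmax", "maxbmm"]
bmm_mode_settings = ["pact", "pactsym", "maxsym", "maxbmm"]
shared_modes = ["max_perToken", "max_perCh"]  # now accepted by every check

def check_mode(name: str, mode: str, settings: list) -> None:
    # Hypothetical helper mirroring the membership test in check_config
    if mode not in (settings + shared_modes):
        raise ValueError(f"{name} = {mode} is not set to one of the following: {settings}")

check_mode("bmm1_qm1_mode", "maxbmm", bmm_mode_settings)  # newly added bmm mode
check_mode("qa_mode", "max_perCh", qa_mode_settings)      # passes only via shared_modes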
