
Commit ca3f763

improve transformers tracing for last layers
Signed-off-by: cliu-us <[email protected]>
1 parent 136c12f commit ca3f763

File tree: 3 files changed, +41 −22 lines

fms_mo/dq.py

Lines changed: 11 additions & 8 deletions
@@ -105,6 +105,12 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         or not isinstance(model_args.torch_dtype, str)
         else getattr(torch, model_args.torch_dtype)
     )
+    # NOTE for models that cannot fit in 1 GPU, keep it on CPU and use block-wise calibration.
+    # or leverage HF's device_map="auto", BUT tracing will not work properly with "auto"
+    total_gpu_memory = 1e-5
+    if torch.cuda.is_available():
+        total_gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
+
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -113,8 +119,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         revision="main",
         use_auth_token=True if model_args.use_auth_token else None,
         torch_dtype=torch_dtype,
-        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
-        device_map="auto" if model_args.low_cpu_mem_usage else None,
+        device_map=model_args.device_map,
+        low_cpu_mem_usage=bool(model_args.device_map),
     )
 
     embedding_size = model.get_input_embeddings().weight.shape[0]
@@ -125,11 +131,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Model is at {model.device} after intialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
     qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
-    # for models that cannot fit in 1 GPU, keep it on CPU and use block-wise calibration.
-    # or leverage HF's device_map="auto"
-    total_gpu_memory = 1e-5
-    if torch.cuda.is_available():
-        total_gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
+
     model_size = model_size_Wb(model, unit="GB")
     gpu_mem_util_per = model_size / total_gpu_memory
 
@@ -145,7 +147,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         name in model_args.model_name_or_path for name in known_large_models
     ) or (gpu_mem_util_per > 0.7)
     dev = "cpu" if qcfg["large_model"] else "cuda"
-    model.to(dev)
+    if model_args.device_map is None:
+        model.to(dev)
 
     if hasattr(model.config, "model_type"):
         qcfg["model_type"] = model.config.model_type

fms_mo/fx/dynamo_utils.py

Lines changed: 25 additions & 10 deletions
@@ -1006,17 +1006,10 @@ def cus_backend_model_analyzer(
     if len(all_conv) > 0:
         skip_candidates += find_conv_on_shortcut_gm(gm_fx, lut_fx_mod_name_to_org)
 
-    # Check 2. first/last, see Note 2 and 3
+    # Check 2. first/last, see Note 2 and 3, NOTE that transformers are handled differently
     if qcfg["N_backend_called"] > 1:
         skip_candidates += []
-    elif is_transformers:
-        _, last_only = find_1st_last_gm(
-            gm_fx,
-            return_1st_last_sep=True,
-            lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
-        )
-        skip_candidates += last_only
-    else:
+    elif not is_transformers:
         # see Note 4
         skip_candidates += find_1st_last_gm(
             gm_fx, lut_fx_mod_name_to_org=lut_fx_mod_name_to_org
@@ -1082,6 +1075,7 @@ def cus_backend_model_analyzer(
     model_to_be_traced = model
     model_param_size = 999
 
+    is_transformers = issubclass(type(model), PreTrainedModel)
     if model_param_size > 1:
         # Standard
         import sys
@@ -1091,7 +1085,7 @@ def cus_backend_model_analyzer(
 
     cus_bknd = partial(
         cus_backend_model_analyzer,
-        is_transformers=issubclass(type(model), PreTrainedModel),
+        is_transformers=is_transformers,
         plotsvg=plotsvg,
     )
 
@@ -1104,6 +1098,27 @@ def cus_backend_model_analyzer(
     if "bmm_prep" not in qcfg:
         qcfg["bmm_prep"] = {"which2patch_contextmanager": None, "layers_with_bmm": {}}
 
+    if is_transformers:
+        # NOTE simplified method to determine 1st/last modules for transformers.
+        # will not work if model has multiple parallel heads at the end, e.g. obj det
+        def call_seq_hook(mod, *_args, **_kwargs):
+            qcfg["mod_call_seq"].append(lut_weight2modname[mod.weight])
+
+        h_hooks = []
+        qcfg["mod_call_seq"] = []
+        for n, m in model.named_modules():
+            if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
+                h_hooks.append(m.register_forward_hook(call_seq_hook))
+
+        with torch.no_grad():
+            model(**sample_inp)
+
+        for h in h_hooks:
+            h.remove()
+
+        # only add last layer
+        qcfg["qskip_layer_name"] += qcfg["mod_call_seq"][-1]
+
     with torch.no_grad():
         model_opt = torch.compile(
             model_to_be_traced,
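
The forward-hook trick added above can be exercised outside the analyzer. Below is a self-contained sketch on a made-up toy model: it registers hooks on every Linear/Conv2d, runs one forward pass, and treats the last module called as the candidate to exclude from quantization. The toy module and the module-to-name lookup are illustrative only; the commit itself maps weights to names via lut_weight2modname and stores the sequence in qcfg.

import torch

class ToyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(16, 32)
        self.ffn = torch.nn.Linear(32, 32)
        self.lm_head = torch.nn.Linear(32, 100)

    def forward(self, x):
        return self.lm_head(self.ffn(self.proj_in(x)))

model = ToyLM()
name_of = {m: n for n, m in model.named_modules()}  # module -> qualified name
call_seq, hooks = [], []

def call_seq_hook(mod, *_args, **_kwargs):
    # forward hooks fire in execution order, so call_seq ends with the last layer used
    call_seq.append(name_of[mod])

for _, m in model.named_modules():
    if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
        hooks.append(m.register_forward_hook(call_seq_hook))

with torch.no_grad():
    model(torch.randn(2, 16))

for h in hooks:
    h.remove()

print(call_seq[-1])  # "lm_head" -> skip candidate, matching the commit's intent

As the in-code NOTE warns, this simplification breaks down when several heads run in parallel at the end of the model, since only one of them can be last in the recorded call sequence.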

fms_mo/training_args.py

Lines changed: 5 additions & 4 deletions
@@ -56,11 +56,12 @@ class ModelArguments(TypeChecker):
 
     model_name_or_path: str = field(default="facebook/opt-125m")
     torch_dtype: str = field(default="bfloat16")
-    low_cpu_mem_usage: bool = field(
-        default=False,
+    device_map: Optional[str] = field(
+        default=None,
         metadata={
-            "help": "When set to True, leverage device_map='auto' and let HF to move modules"
-            "between cpu and cuda automatically during inference."
+            "help": "can be 'auto', 'balanced', 'balanced_low_0', 'sequential' or something like"
+            " {'encoder':'cuda:1', 'decoder': 'cuda:2'}.\n"
+            "HF will try to move modules between cpu and cuda automatically during inference."
         },
     )
     use_fast_tokenizer: bool = field(
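
A quick sanity check of how the new field parses from the command line, assuming HF's HfArgumentParser is used to build ModelArguments; the minimal dataclass below drops the TypeChecker base and the other fields, so it is a sketch rather than the repo's class.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="facebook/opt-125m")
    device_map: Optional[str] = field(
        default=None,
        metadata={"help": "e.g. 'auto', 'balanced', 'balanced_low_0', 'sequential'"},
    )

parser = HfArgumentParser(ModelArguments)
(args,) = parser.parse_args_into_dataclasses(args=["--device_map", "auto"])
print(args.device_map)        # "auto"
print(bool(args.device_map))  # True, which is what run_dq now passes as low_cpu_mem_usage

Note that the string presets are the straightforward path from the CLI; a dict-style map like {'encoder': 'cuda:1', 'decoder': 'cuda:2'} from the help text would need extra parsing before reaching from_pretrained, since the field is declared as Optional[str].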
