
Commit 6be43cf

jackzhxng authored and facebook-github-bot committed
Refactor dtype handling in export_llama (#9430)
Summary: No more converting from fp32 -> checkpoint dtype (fp16 or lower) -> back to the dtype override (fp32), which was losing precision on buffers. Also cleans up dtype handling so that it now happens only outside of model.py, whose responsibility should just be loading the model. Differential Revision: D71515138
1 parent 45219f3 commit 6be43cf
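
For context on the precision issue described in the summary, here is a minimal standalone sketch (not part of this commit, illustrative values only) of how the old fp32 -> fp16 -> fp32 round trip permanently drops information from an fp32 buffer such as precomputed RoPE frequencies:

import torch

# An fp32 buffer with values that fp16 cannot represent exactly.
buf = torch.tensor([0.1234567, 1e-5, 3.1415927], dtype=torch.float32)

# Old flow: cast to the checkpoint dtype (fp16), then back to the fp32 dtype override.
round_tripped = buf.to(torch.float16).to(torch.float32)

print(buf - round_tripped)  # non-zero: the fp16 hop discarded mantissa bits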

2 files changed: +44 −74 lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 37 additions & 63 deletions
@@ -322,7 +322,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         default="fp32",
         type=str,
         choices=["fp32", "fp16", "bf16"],
-        help="Override the dtype of the model (default is the checkpoint dtype)."
+        help="Provide the dtype of the model."
         "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
     )

@@ -565,43 +565,40 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
-    # dtype override
-    if args.dtype_override is not None:
-        dtype_override = DType[args.dtype_override]
-    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
-        dtype_override = DType["fp16"]
-    else:
-        dtype_override = None
-
-    return (
-        _load_llama_model(
-            args.model,
-            checkpoint=checkpoint_path,
-            checkpoint_dir=checkpoint_dir,
-            params_path=params_path,
-            use_kv_cache=args.use_kv_cache,
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            generate_full_logits=args.generate_full_logits,
-            weight_type=weight_type,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            calibration_data=args.calibration_data,
-            tokenizer_path=args.tokenizer_path,
-            verbose=args.verbose,
-            max_seq_len=args.max_seq_length,
-            max_context_len=args.max_context_length,
-            input_prune_map_path=args.input_prune_map,
-            output_prune_map_path=args.output_prune_map,
-            metadata_str=args.metadata,
-            dtype_override=dtype_override,
-            args=args,
-        )
-        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(args.model, dtype_override, args))
+    # Convert dtype override string arg to actual type.
+    dtype_override = DType[args.dtype_override]
+
+    edge_manager = _load_llama_model(
+        args.model,
+        checkpoint=checkpoint_path,
+        checkpoint_dir=checkpoint_dir,
+        params_path=params_path,
+        use_kv_cache=args.use_kv_cache,
+        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+        generate_full_logits=args.generate_full_logits,
+        weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
+        verbose=args.verbose,
+        max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
+        input_prune_map_path=args.input_prune_map,
+        output_prune_map_path=args.output_prune_map,
+        metadata_str=args.metadata,
+        dtype_override=dtype_override,
+        args=args,
     )
 
+    # At this point, the model is loaded in the default fp32.
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.set_output_dir(output_dir_path).source_transform(_get_source_transforms(args.model, dtype_override, args))
+
+    return edge_manager
+
 
 def get_quantizer_and_quant_params(args):
     pt2e_quant_params = get_pt2e_quantization_params(
@@ -1006,6 +1003,8 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
+
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
@@ -1022,41 +1021,16 @@ def _load_llama_model(
             enable_dynamic_shape=enable_dynamic_shape,
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
+            dtype=torch_dtype,
             args=args,
         )
     )
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        model = model.to(dtype=torch_dtype)
-        dtype = dtype_override
-    else:
-        state_dict = model.state_dict()
-        dtype = state_dict[next(iter(state_dict))].dtype
-        assert dtype in [
-            torch.bfloat16,
-            torch.float16,
-            torch.float32,
-        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-        logging.info(f"Loaded model with dtype={dtype}")
-
-        if dtype == torch.bfloat16:
-            dtype = DType.bf16
-        elif dtype == torch.float16:
-            dtype = DType.fp16
-        elif dtype == torch.float32:
-            dtype = DType.fp32
-        else:
-            raise ValueError(f"Unsupported dtype {dtype}")
 
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
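
The net effect of the export_llama_lib.py changes is a single, explicit dtype conversion applied after loading. Below is a minimal sketch of that flow; the stub model and hard-coded torch dtype are hypothetical stand-ins, while the real code goes through _load_llama_model and DType.to_torch_dtype():

import torch
import torch.nn as nn

def load_model_stub() -> nn.Module:
    # Stand-in for _load_llama_model: the eager model now always comes back in fp32.
    return nn.Linear(8, 8)

# e.g. the torch dtype that a "bf16" dtype override would resolve to
dtype_override = torch.bfloat16

model = load_model_stub()
assert next(model.parameters()).dtype == torch.float32

# One conversion, done outside model.py, mirroring
# edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
model = model.to(dtype=dtype_override)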

examples/models/llama/model.py

Lines changed: 7 additions & 11 deletions
@@ -122,9 +122,6 @@ def __init__(self, **kwargs):
             """
         )
 
-        # Get checkpoint dtype.
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
@@ -171,7 +168,9 @@ def __init__(self, **kwargs):
         # Within the device="meta" context, tensors that are created do not carry data.
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
+            # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
+        self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -241,6 +240,10 @@ def __init__(self, **kwargs):
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
         missing, unexpected = self.model_.load_state_dict(
             checkpoint,
             strict=False,
@@ -277,14 +280,7 @@ def __init__(self, **kwargs):
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            # convert to the type of the provided checkpoint
-            # input and output are torch.long, so signature unchanged
-            return self.model_.to(self.dtype)
-        else:
-            # int8 quantization code has some bf16,
-            # switch all to FP32
-            return self.model_.to(torch.float32)
+        return self.model_
 
     def get_example_inputs(self):
         if self.use_kv_cache:
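
As a companion to the model.py changes above, here is a small standalone sketch (a hypothetical TinyModel standing in for the real Transformer) of the meta-device loading pattern the comments describe: construct without allocating storage, materialize from the checkpoint with assign=True, and let the caller apply the dtype conversion once at the end:

import torch
import torch.nn as nn

class TinyModel(nn.Module):  # hypothetical stand-in for the Llama Transformer
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)

# Construct on the meta device: tensors carry shape/dtype metadata but no storage.
with torch.device("meta"):
    model = TinyModel()

# A checkpoint saved in a lower precision (bf16 here).
checkpoint = {k: v.to(torch.bfloat16) for k, v in TinyModel().state_dict().items()}

# assign=True attaches the checkpoint tensors to the module instead of copying into
# the storage-less meta tensors (an in-place copy would be a no-op on meta).
model.load_state_dict(checkpoint, strict=False, assign=True)

# The dtype conversion now happens once, in the caller (export_llama_lib.py).
model = model.to(dtype=torch.float32)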
