
Commit e68b028

initial
1 parent 643c381 commit e68b028

File tree: 3 files changed, +42 -124 lines

examples/models/llama/export_llama_lib.py

Lines changed: 40 additions & 59 deletions
```diff
@@ -561,42 +561,49 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
-    # dtype override
-    if args.dtype_override is not None:
-        dtype_override = DType[args.dtype_override]
-    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
+    # Convert dtype override string to actual type.
+    if args.quantization_mode in ["8da4w", "8da4w-gptq"]:
         dtype_override = DType["fp16"]
     else:
-        dtype_override = None
+        dtype_override = DType[args.dtype_override]
 
-    return (
-        _load_llama_model(
-            args.model,
-            checkpoint=checkpoint_path,
-            checkpoint_dir=checkpoint_dir,
-            params_path=params_path,
-            use_kv_cache=args.use_kv_cache,
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            generate_full_logits=args.generate_full_logits,
-            weight_type=weight_type,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            calibration_data=args.calibration_data,
-            tokenizer_path=args.tokenizer_path,
-            verbose=args.verbose,
-            max_seq_len=args.max_seq_length,
-            max_context_len=args.max_context_length,
-            input_prune_map_path=args.input_prune_map,
-            output_prune_map_path=args.output_prune_map,
-            metadata_str=args.metadata,
-            dtype_override=dtype_override,
-            args=args,
-        )
-        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(args.model, dtype_override, args))
+    edge_manager = _load_llama_model(
+        args.model,
+        checkpoint=checkpoint_path,
+        checkpoint_dir=checkpoint_dir,
+        params_path=params_path,
+        use_kv_cache=args.use_kv_cache,
+        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+        generate_full_logits=args.generate_full_logits,
+        weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
+        verbose=args.verbose,
+        max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
+        input_prune_map_path=args.input_prune_map,
+        output_prune_map_path=args.output_prune_map,
+        metadata_str=args.metadata,
+        dtype_override=dtype_override,
+        args=args,
     )
+    .set_output_dir(output_dir_path)
+    .source_transform(_get_source_transforms(args.model, dtype_override, args))
+
+    # Override dtype of the model as specified by the user args.
+    if dtype_override:
+        assert isinstance(
+            dtype_override, DType
+        ), "Override dtype needs to be of type <DType>"
+        torch_dtype = dtype_override.to_torch_dtype()
+        logging.info(f"model.to {torch_dtype}")
+        edge_manager.model = edge_manager.model.to(dtype=torch_dtype)
+
+    return edge_manager
 
 
 def get_quantizer_and_quant_params(args):
```
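In short, the override string is now always resolved to a `DType` up front (fp16 is forced for the 8da4w modes, as in the hunk above) and the cast is applied to the already-loaded model, instead of passing an optional `dtype_override` down into `_load_llama_model`. A minimal sketch of that resolve-then-cast flow, using a plain string-to-dtype mapping in place of the real `DType` enum; the `_DTYPE_MAP` and `resolve_and_apply_dtype` names are illustrative, not part of the codebase:

```python
import logging
import torch

# Illustrative stand-in for the DType helper used in export_llama_lib.py.
_DTYPE_MAP = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def resolve_and_apply_dtype(
    model: torch.nn.Module, dtype_override: str, quantization_mode: str | None = None
) -> torch.nn.Module:
    """Mirror the new flow: resolve the dtype string first, then cast the loaded model."""
    if quantization_mode in ["8da4w", "8da4w-gptq"]:
        dtype_override = "fp16"  # forced to fp16 for these quant modes, as in the diff above
    torch_dtype = _DTYPE_MAP[dtype_override]
    logging.info(f"model.to {torch_dtype}")
    return model.to(dtype=torch_dtype)
```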
```diff
@@ -971,38 +978,12 @@ def _load_llama_model(
             args=args,
         )
     )
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        model = model.to(dtype=torch_dtype)
-        dtype = dtype_override
-    else:
-        state_dict = model.state_dict()
-        dtype = state_dict[next(iter(state_dict))].dtype
-        assert dtype in [
-            torch.bfloat16,
-            torch.float16,
-            torch.float32,
-        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-    logging.info(f"Loaded model with dtype={dtype}")
-
-    if dtype == torch.bfloat16:
-        dtype = DType.bf16
-    elif dtype == torch.float16:
-        dtype = DType.fp16
-    elif dtype == torch.float32:
-        dtype = DType.fp32
-    else:
-        raise ValueError(f"Unsupported dtype {dtype}")
 
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
```
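With this change `_load_llama_model` no longer guesses the dtype: the caller passes the resolved `dtype_override` straight into `LLMEdgeManager`. For context, the deleted fallback peeked at the checkpoint to discover its floating-point dtype; a self-contained sketch of that inference (the helper name is illustrative):

```python
import torch


def infer_checkpoint_dtype(model: torch.nn.Module) -> torch.dtype:
    """Roughly the deleted fallback: read the dtype of the first state-dict entry
    and insist on a supported floating-point type."""
    state_dict = model.state_dict()
    dtype = state_dict[next(iter(state_dict))].dtype
    assert dtype in [
        torch.bfloat16,
        torch.float16,
        torch.float32,
    ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
    return dtype


model = torch.nn.Linear(4, 4, dtype=torch.bfloat16)
print(infer_checkpoint_dtype(model))  # torch.bfloat16
```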

examples/models/llama/model.py

Lines changed: 1 addition & 64 deletions
```diff
@@ -171,70 +171,7 @@ def __init__(self, **kwargs):
         # Within the device="meta" context, tensors that are created do not carry data.
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
-            self.model_ = Transformer(model_args)
-
-        if "int8" in str(checkpoint_path):
-            print("Using int8 weight-only quantization!")
-            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.examples.models.source_transformation.quantize`
-            from ..source_transformation.quantize import WeightOnlyInt8QuantHandler
-
-            simple_quantizer = WeightOnlyInt8QuantHandler(self.model_)
-            self.model_ = simple_quantizer.convert_for_runtime()
-        elif "8da4w" in str(checkpoint_path):
-            print("Using int4 weight and int8 dynamic activation quantization!")
-            from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-
-            self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
-                self.model_
-            )
-        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
-            print("Using SPIN quantization.")
-            self._transform_for_pre_quantization(checkpoint, model_args)
-
-            from .source_transformation.pre_quantization import (
-                sanitize_checkpoint_from_pre_quantization,
-            )
-
-            sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif hasattr(self.args, "use_qat") and self.args.use_qat:
-            print("Using QAT quantization.")
-            self._transform_for_pre_quantization(checkpoint, model_args)
-            if hasattr(self.args, "use_lora") and self.args.use_lora:
-                assert model_args.lora_args["rank"] == self.args.use_lora
-                from .source_transformation.lora import (
-                    transform_linear_for_lora_after_quantization,
-                )
-
-                self.model_ = transform_linear_for_lora_after_quantization(
-                    self.model_,
-                    checkpoint,
-                    self.args.use_lora,
-                )
-
-            from .source_transformation.pre_quantization import (
-                sanitize_checkpoint_from_pre_quantization,
-            )
-
-            sanitize_checkpoint_from_pre_quantization(checkpoint)
-
-        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
-            from .source_transformation.attention_sink import enable_attention_sink
-
-            attention_sink_params = self.args.use_attention_sink.split(",")
-            assert len(attention_sink_params) == 3
-            sink_size = int(attention_sink_params[0])
-            window_size = int(attention_sink_params[1])
-            eviction_batch_size = int(attention_sink_params[2])
-
-            assert self.args.max_context_length == sink_size + window_size
-
-            self.model_ = enable_attention_sink(
-                module=self.model_,
-                params=model_args,
-                sink_size=sink_size,
-                window_size=window_size,
-                eviction_batch_size=eviction_batch_size,
-            )
+            self.model_ = Transformer(model_args).to(dtype=self.dtype)
 
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
```
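The constructor now only builds the `Transformer` on the meta device and casts it to `self.dtype`; the surrounding comments refer to the `assign=True` loading step that attaches real storage afterwards. A small, self-contained sketch of that meta-device pattern, assuming a toy module and a synthetic checkpoint rather than the actual llama `Transformer`:

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Toy stand-in for the Transformer; only the loading pattern matters here."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


# Under device="meta", parameters carry shape/dtype metadata but no data.
with torch.device("meta"):
    model = TinyModel().to(dtype=torch.bfloat16)

# Synthetic checkpoint with real (CPU) storage and matching shapes/dtypes.
checkpoint = {k: torch.zeros_like(v, device="cpu") for k, v in model.state_dict().items()}

# assign=True swaps the meta tensors for the checkpoint tensors instead of copying
# into them; an in-place copy would fail because meta tensors have no storage.
model.load_state_dict(checkpoint, assign=True)
print(next(model.parameters()).device, next(model.parameters()).dtype)  # cpu torch.bfloat16
```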

examples/models/llama/source_transformation/quantize.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -611,7 +611,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    mod.weight.float(),
+                    mod.weight,
                     range_min,
                     range_max,
                     torch.int8,
```
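The only change here is dropping the `.float()` upcast, so `dynamically_quantize_per_channel` now sees the weight in whatever dtype the model was cast to earlier (e.g. fp16 or bf16). For intuition, a rough sketch of symmetric per-channel int8 quantization; this is illustrative only, not the actual `dynamically_quantize_per_channel` implementation:

```python
import torch


def per_channel_int8_sketch(weight: torch.Tensor, qmin: int = -128, qmax: int = 127):
    """Symmetric per-channel int8 quantization: one scale per output channel (dim 0),
    computed in the weight's own dtype."""
    max_abs = weight.abs().amax(dim=1, keepdim=True)  # per-row dynamic range
    scales = (max_abs / qmax).clamp(min=1e-5)         # guard tiny/zero rows
    q = torch.clamp(torch.round(weight / scales), qmin, qmax).to(torch.int8)
    return q, scales.squeeze(1)


w = torch.randn(8, 16, dtype=torch.float16)
q, scales = per_channel_int8_sketch(w)
print(q.dtype, scales.dtype)  # torch.int8 torch.float16
```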
