Skip to content

Commit f22a09b

Browse files
committed
Pipe in dtype correctly to model.py
1 parent bfe8b06 commit f22a09b

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,7 @@ def _load_llama_model(
983983
enable_dynamic_shape=enable_dynamic_shape,
984984
input_prune_map_path=input_prune_map_path,
985985
output_prune_map_path=output_prune_map_path,
986+
dtype=dtype_override.to_torch_dtype(),
986987
args=args,
987988
)
988989
)

examples/models/llama/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self, **kwargs):
5454
self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
5555
self.max_seq_len = kwargs.get("max_seq_len", 128)
5656
self.max_context_len = kwargs.get("max_context_len", 128)
57-
self.dtype = kwargs.get("dtype_override", None)
57+
self.dtype = kwargs.get("dtype", None)
5858
self.args = kwargs.get("args", None)
5959

6060
assert (

0 commit comments

Comments
 (0)