Skip to content

Commit 8a07376

Browse files
authored
set pipeline_dtype default value to params_dtype in megatron_eagle TransformerConfig (#291)
Signed-off-by: Ye Yu <[email protected]>
1 parent 2b52759 commit 8a07376

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def dict_to_config(
9090
fp16=fp16,
9191
bf16=bf16,
9292
params_dtype=getattr(torch, architecture_config["torch_dtype"]),
93-
pipeline_dtype=None,
93+
pipeline_dtype=getattr(torch, architecture_config["torch_dtype"]),
9494
num_layers=architecture_config.get("num_hidden_layers"),
9595
hidden_size=architecture_config.get("hidden_size"),
9696
ffn_hidden_size=architecture_config.get("intermediate_size"),

0 commit comments

Comments
 (0)