diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index eca83b3df..0b2a886f5 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -90,7 +90,7 @@ def dict_to_config( fp16=fp16, bf16=bf16, params_dtype=getattr(torch, architecture_config["torch_dtype"]), - pipeline_dtype=None, + pipeline_dtype=getattr(torch, architecture_config["torch_dtype"]), num_layers=architecture_config.get("num_hidden_layers"), hidden_size=architecture_config.get("hidden_size"), ffn_hidden_size=architecture_config.get("intermediate_size"),