Skip to content

Commit 00a7e60

Browse files
yeyu-nvidiakinjalpatel27AAnooshehrealAsmacjluo-nv
authored
remove DetachedEagleGPT model and handle all offline mode in the _DynamicEagleGPTModel (#321)
Signed-off-by: Ye Yu <[email protected]> Signed-off-by: Kinjal Patel <[email protected]> Signed-off-by: Asha Anoosheh <[email protected]> Signed-off-by: realAsma <[email protected]> Signed-off-by: Chenjie Luo <[email protected]> Signed-off-by: Riyad Islam <[email protected]> Signed-off-by: Yue <[email protected]> Signed-off-by: h-guo18 <[email protected]> Signed-off-by: Chenjie Luo <[email protected]> Signed-off-by: Keval Morabia <[email protected]> Signed-off-by: omrialmog <[email protected]> Signed-off-by: Jennifer Chen <[email protected]> Co-authored-by: kinjalpatel27 <[email protected]> Co-authored-by: Asha Anoosheh <[email protected]> Co-authored-by: realAsma <[email protected]> Co-authored-by: Chenjie Luo <[email protected]> Co-authored-by: Riyad Islam <[email protected]> Co-authored-by: yueshen2016 <[email protected]> Co-authored-by: h-guo18 <[email protected]> Co-authored-by: Chenjie Luo <[email protected]> Co-authored-by: Keval Morabia <[email protected]> Co-authored-by: omrialmog <[email protected]> Co-authored-by: Jenny Chen <[email protected]>
1 parent b895dc5 commit 00a7e60

File tree

3 files changed

+70
-446
lines changed

3 files changed

+70
-446
lines changed

modelopt/torch/speculative/eagle/conversion.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,21 @@
2424
from ..config import EagleConfig
2525

2626
EagleDMRegistry = _DMRegistryCls(prefix="Eagle") # global instance for the registry
27-
OfflineEagleDMRegistry = _DMRegistryCls(prefix="DetachedEagle") # global instance for the registry
2827

2928

3029
def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertReturnType:
3130
"""Convert the model to a eagle model as per `config`."""
3231
# initialize the true module if necessary
3332
model = model.init_modellike() if isinstance(model, ModelLikeModule) else model
3433

35-
registry = OfflineEagleDMRegistry if config.eagle_offline else EagleDMRegistry
36-
3734
original_cls = type(model)
38-
if original_cls not in registry:
39-
for cls in registry._registry:
35+
if original_cls not in EagleDMRegistry:
36+
for cls in EagleDMRegistry._registry:
4037
if issubclass(original_cls, cls):
41-
registry.register({original_cls: "base_model_class"})(registry[cls])
38+
EagleDMRegistry.register({original_cls: "base_model_class"})(EagleDMRegistry[cls])
4239
break
4340

44-
eagle_model = registry.convert(model)
41+
eagle_model = EagleDMRegistry.convert(model)
4542
eagle_model.modify(
4643
eagle_offline=config.eagle_offline,
4744
eagle_hidden_state_distillation=config.eagle_hidden_state_distillation,

0 commit comments

Comments
 (0)