diff --git a/modelopt/torch/speculative/eagle/conversion.py b/modelopt/torch/speculative/eagle/conversion.py
index a033386d..ecc0c37d 100644
--- a/modelopt/torch/speculative/eagle/conversion.py
+++ b/modelopt/torch/speculative/eagle/conversion.py
@@ -24,7 +24,6 @@ from ..config import EagleConfig
 
 EagleDMRegistry = _DMRegistryCls(prefix="Eagle")  # global instance for the registry
-OfflineEagleDMRegistry = _DMRegistryCls(prefix="DetachedEagle")  # global instance for the registry
 
 
 def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertReturnType:
@@ -32,16 +31,14 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu
     # initialize the true module if necessary
     model = model.init_modellike() if isinstance(model, ModelLikeModule) else model
 
-    registry = OfflineEagleDMRegistry if config.eagle_offline else EagleDMRegistry
-
     original_cls = type(model)
-    if original_cls not in registry:
-        for cls in registry._registry:
+    if original_cls not in EagleDMRegistry:
+        for cls in EagleDMRegistry._registry:
             if issubclass(original_cls, cls):
-                registry.register({original_cls: "base_model_class"})(registry[cls])
+                EagleDMRegistry.register({original_cls: "base_model_class"})(EagleDMRegistry[cls])
                 break
 
-    eagle_model = registry.convert(model)
+    eagle_model = EagleDMRegistry.convert(model)
     eagle_model.modify(
         eagle_offline=config.eagle_offline,
         eagle_hidden_state_distillation=config.eagle_hidden_state_distillation,
diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py
index c2447367..2a0e63a3 100644
--- a/modelopt/torch/speculative/plugins/megatron_eagle.py
+++ b/modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -52,7 +52,7 @@ from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
 from packaging.version import Version
 
-from ..eagle.conversion import EagleDMRegistry, OfflineEagleDMRegistry
+from ..eagle.conversion import EagleDMRegistry
 from ..eagle.eagle_model import EagleModel
 from ..utils import (
     AcceptanceRateValidation,
@@ -745,6 +745,10 @@ def modify(
             eagle_architecture_config=eagle_architecture_config,
         )
 
+        # sequence_parallel is not used in offline eagle
+        if self.eagle_offline:
+            self.config.sequence_parallel = False
+
         self.eagle_config = dict_to_config(
             eagle_architecture_config,
             self.config.use_cpu_initialization,
@@ -760,6 +764,7 @@ def modify(
 
         # Use default aux_hidden_state layers if use_aux_hidden_state is True
         # but no layer id is given
+        # layer ids are not used in offline eagle, but we need to set this to get the correct fc_input_size_multiplier
         if (
             self.eagle_config.use_aux_hidden_state
             and len(self.eagle_config.eagle_aux_hidden_state_layer_ids) == 0
@@ -810,6 +815,8 @@ def modify(
                 self.eagle_config.eagle_aux_hidden_state_layer_ids
             )
             eagle_config.use_mtp_layernorm = self.eagle_config.use_mtp_layernorm
+            eagle_config.draft_vocab_size = self.eagle_config.draft_vocab_size
+            eagle_config.has_lm_head = self.eagle_config.has_lm_head
             self.eagle_module = EagleModule(
                 eagle_config,
                 self.rotary_pos_emb,
@@ -837,24 +844,25 @@ def modify(
         self.kld = logits_kld_loss
 
     def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True):
-        """When _aux_hidden_states is not empty, then this is EAGLE-3.
+        """In online mode, when _aux_hidden_states is not empty, this is EAGLE-3.
 
         Args:
             hidden_states: last hidden_states
             apply_fc: whether to apply EAGLE3 fc
         """
-        if len(self._aux_hidden_states) == 0:
-            return hidden_states
+        if not self.eagle_offline:
+            if len(self._aux_hidden_states) == 0:
+                return hidden_states
 
-        # [s / TP, b, len(self._aux_hidden_states) * h]
-        aux_hidden_states = torch.cat(self._aux_hidden_states, dim=-1)
-        self._aux_hidden_states.clear()
+            # [s / TP, b, len(self._aux_hidden_states) * h]
+            hidden_states = torch.cat(self._aux_hidden_states, dim=-1)
+            self._aux_hidden_states.clear()
 
         if apply_fc:
             # [s / TP, b, 3h] -> [s / TP, b, h]
-            return self.eagle_module.fc(aux_hidden_states)[0]
+            return self.eagle_module.fc(hidden_states)[0]
         else:
-            return aux_hidden_states
+            return hidden_states
 
     def _get_eagle_module_inputs(
         self,
@@ -1183,34 +1191,51 @@ def forward(
         return_eagle_inputs: bool = False,
         **kwargs,
     ) -> torch.Tensor:
-        if input_ids is not None and (position_ids is None or attention_mask is None):
+        if position_ids is None or attention_mask is None:
             attention_mask, position_ids = get_default_attention_mask_and_position_ids(input_ids)
 
-        # When return_eagle_inputs is True, return decoder_input_for_eagle.
-        # When LLM, decoder_input_for_eagle is just the text embeddings. However, when VLM
-        # decoder_input_for_eagle will also contain projected image/video embeddings.
-        hidden_states, decoder_input_for_eagle = self._base_model_forward(
-            input_ids,
-            position_ids,
-            attention_mask,
-            decoder_input,
-            inference_params,
-            packed_seq_params,
-            extra_block_kwargs,
-            return_eagle_inputs=return_eagle_inputs,
-        )
+        if self.eagle_offline:
+            # In offline eagle, aux_hidden_states and hidden_states are provided
+            # as kwargs, so _base_model_forward is skipped.
+            if return_eagle_inputs:
+                raise ValueError("return_eagle_inputs is unsupported in EAGLE offline mode.")
+            aux_hidden_states = kwargs.get("aux_hidden_states")
+            hidden_states = kwargs.get("hidden_states")
+            if aux_hidden_states is None or hidden_states is None:
+                raise ValueError(
+                    "EAGLE offline mode requires kwargs: aux_hidden_states=[s,b,k*h], "
+                    "hidden_states=[s,b,h]."
+                )
+        else:
+            # When return_eagle_inputs is True, return decoder_input_for_eagle.
+            # For LLM, decoder_input_for_eagle is just the text embeddings. However, for VLM
+            # decoder_input_for_eagle will also contain projected image/video embeddings.
+            hidden_states, decoder_input_for_eagle = self._base_model_forward(
+                input_ids,
+                position_ids,
+                attention_mask,
+                decoder_input,
+                inference_params,
+                packed_seq_params,
+                extra_block_kwargs,
+                return_eagle_inputs=return_eagle_inputs,
+            )
 
-        # Typically, this is only the case when PP > 1.
-        if not self.post_process:
-            return hidden_states
+            # Typically, this is only the case when PP > 1.
+            if not self.post_process:
+                return hidden_states
 
         output_weight = None
         if self.share_embeddings_and_output_weights:
             output_weight = self.shared_embedding_or_output_weight()
         logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight)
 
+        if self.eagle_offline:
+            eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(
+                aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state
+            )
         # If EAGLE-3, aux_hidden_states are gathered by the forward_hook
-        if return_eagle_inputs:
+        elif return_eagle_inputs:
             eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(
                 hidden_states, apply_fc=False
             )
@@ -1232,7 +1257,7 @@ def forward(
                 hidden_states, apply_fc=True
             )
 
-        # Either inference or calibration mode, we want to make sure all weights have been exercised.
+        # In calibration mode, we want to make sure all weights have been exercised.
         # This makes sure all quantized weights have amax calibrated
         if inference_params is None or self.calibration_mode:
             eagle_inputs_0 = self._get_eagle_module_inputs(
@@ -1254,6 +1279,17 @@ def forward(
         # all eagle weights have been exercised for quantization calibration purpose.
         if labels is None:
             return logits_sbh.transpose(0, 1).contiguous()
+        elif labels.shape[1] == input_ids.shape[1] - 1:
+            # For offline training, labels may be 1 token shorter than input_ids.
+            # Right-pad labels with a single 0 so their seq_len matches input_ids.
+            # This introduces a small training error when logit_distillation is
+            # False, and the test accuracy of the last token will be wrong.
+            right_token_pad = torch.zeros(
+                (labels.shape[0], 1),
+                dtype=labels.dtype,
+                device=labels.device,
+            )
+            labels = torch.cat((labels, right_token_pad), dim=-1)
 
         # If eagle_freeze_base_model is set to True,
         # the base model is frozen .
@@ -1305,7 +1341,7 @@ def forward(
             packed_seq_params=packed_seq_params,
             **(extra_block_kwargs or {}),
         )
-        eagle_logits_1 = eagle_logits_2x[labels.shape[1] :, :, :]
+        eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :]
 
         loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1)
         # [b, s - 2]
@@ -1713,408 +1749,6 @@ def pseudo_speculative_generate(
     return base_token, draft_tokens
 
 
-@OfflineEagleDMRegistry.register({GPTModel: "megatron.core.models.gpt.GPTModel"})
-class _DetachedEagleGPTModel(_DynamicEagleGPTModel):
-    """A wrapper for detached Eagle module."""
-
-    def modify(
-        self,
-        eagle_offline,
-        eagle_hidden_state_distillation,
-        eagle_self_logit_distillation,
-        eagle_freeze_base_model,
-        eagle_report_acc,
-        eagle_reuse_base_decoder,
-        eagle_loss_decay_factor,
-        eagle_architecture_config,
-    ):
-        super(_DynamicEagleGPTModel, self).modify(
-            eagle_offline=eagle_offline,
-            eagle_hidden_state_distillation=eagle_hidden_state_distillation,
-            eagle_self_logit_distillation=eagle_self_logit_distillation,
-            eagle_freeze_base_model=eagle_freeze_base_model,
-            eagle_report_acc=eagle_report_acc,
-            eagle_reuse_base_decoder=eagle_reuse_base_decoder,
-            eagle_loss_decay_factor=eagle_loss_decay_factor,
-            eagle_architecture_config=eagle_architecture_config,
-        )
-
-        # Freeze all parameters
-        if self.eagle_freeze_base_model:
-            for name, param in self.named_parameters():
-                param.requires_grad = False
-
-        self.eagle_config = dict_to_config(
-            eagle_architecture_config,
-            self.config.use_cpu_initialization,
-            self.config.fp16,
-            self.config.bf16,
-        )
-
-        assert not eagle_reuse_base_decoder, (
-            "_DetachedEagleGPTModel does not have a base model so eagle_reuse_base_decoder must be False!"
-        )
-
-        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-            assert eagle_self_logit_distillation, (
-                "Only logit distillation is supported when draft_vocab_size != vocab_size!"
-            )
-
-        # Use default aux_hidden_state layers if use_aux_hidden_state is True
-        # but no layer id is given
-        # layer ids are not used in detached eagle, but we need to set this to have correct fc_input_size_multiplier
-        if (
-            self.eagle_config.use_aux_hidden_state
-            and len(self.eagle_config.eagle_aux_hidden_state_layer_ids) == 0
-        ):
-            self._set_default_aux_hidden_state_layers()
-
-        # Only the last PP stage has the additional projection and decoder layer.
-        # This is to simplify the export.
-        if self.post_process:
-            rotary_pos_emb = RotaryEmbedding(
-                kv_channels=self.eagle_config.kv_channels,
-                rotary_percent=self.eagle_config.rotary_percent,
-                rotary_interleaved=False,
-                seq_len_interpolation_factor=None,
-                rotary_base=self.eagle_config.rotary_base,
-                rope_scaling=self.eagle_config.rope_scaling,
-                rope_scaling_factor=self.eagle_config.rope_scaling_factor,
-                use_cpu_initialization=self.eagle_config.use_cpu_initialization,
-            )
-
-            self.eagle_module = EagleModule(
-                self.eagle_config,
-                rotary_pos_emb,
-                bias=False,
-            )
-
-            # Eagle loss functions
-            self.kld = logits_kld_loss
-
-    def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True):
-        if apply_fc:
-            # [s / TP, b, 3h] -> [s / TP, b, h]
-            return self.eagle_module.fc(hidden_states)[0]
-        else:
-            return hidden_states
-
-    def _get_detached_eagle_module_inputs(
-        self,
-        input_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_ids: torch.Tensor,
-        features: torch.Tensor | None = None,
-    ):
-        """Getting EAGLE module inputs."""
-        b = hidden_states.shape[1]
-        h = hidden_states.shape[2]
-
-        # [b, 1]
-        id_padding = torch.zeros((b, 1), dtype=input_ids.dtype, device=input_ids.device)
-        padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1)
-
-        rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1])
-
-        attn_mask = attention_mask.clone().detach()
-        attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
-        attn_mask[:, :, -1, :] = True
-        attn_mask[:, :, :, -1] = True
-
-        eagle_inputs = {}
-
-        assert self.eagle_config.parallel_draft_step == 1, (
-            "Detached Eagle module does not support parallel draft yet!"
-        )
-        if features is None:
-            eagle_inputs["input_ids"] = padded_input_ids
-            eagle_inputs["hidden_states"] = hidden_states
-            eagle_inputs["attention_mask"] = attn_mask
-            eagle_inputs["position_ids"] = position_ids
-            eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
-        elif features.shape[0] == hidden_states.shape[0]:
-            eagle_inputs["input_ids"] = torch.cat(
-                (padded_input_ids, padded_input_ids),
-                dim=-1,
-            )
-            eagle_inputs["hidden_states"] = torch.cat(
-                (
-                    hidden_states,
-                    torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device),
-                    features[:-1, :, :],
-                ),
-                dim=0,
-            )
-            eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 2)
-            eagle_inputs["position_ids"] = torch.cat((position_ids, position_ids), dim=-1)
-
-            if rotary_pos_emb is not None:
-                eagle_inputs["rotary_pos_emb"] = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=0)
-            else:
-                # [TODO] (yeyu): there will be problem here with MLA
-                eagle_inputs["rotary_pos_emb"] = None
-        elif features.shape[0] == hidden_states.shape[0] * 2:
-            eagle_inputs["input_ids"] = torch.cat(
-                (padded_input_ids, padded_input_ids, padded_input_ids),
-                dim=-1,
-            )
-            eagle_inputs["hidden_states"] = torch.cat(
-                (
-                    hidden_states,
-                    torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device),
-                    features[:-1, :, :],
-                ),
-                dim=0,
-            )
-
-            eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 3)
-            eagle_inputs["position_ids"] = torch.cat(
-                (position_ids, position_ids, position_ids), dim=-1
-            )
-
-            if rotary_pos_emb is not None:
-                eagle_inputs["rotary_pos_emb"] = torch.cat(
-                    (rotary_pos_emb, rotary_pos_emb, rotary_pos_emb),
-                    dim=0,
-                )
-            else:
-                # [TODO] (yeyu): there will be problem here with MLA
-                eagle_inputs["rotary_pos_emb"] = None
-        else:
-            eagle_inputs["input_ids"] = torch.cat(
-                (padded_input_ids, padded_input_ids, padded_input_ids, padded_input_ids),
-                dim=-1,
-            )
-            eagle_inputs["hidden_states"] = torch.cat(
-                (
-                    hidden_states,
-                    torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device),
-                    features[:-1, :, :],
-                ),
-                dim=0,
-            )
-
-            eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 4)
-            eagle_inputs["position_ids"] = torch.cat(
-                (position_ids, position_ids, position_ids, position_ids), dim=-1
-            )
-
-            if rotary_pos_emb is not None:
-                eagle_inputs["rotary_pos_emb"] = torch.cat(
-                    (rotary_pos_emb, rotary_pos_emb, rotary_pos_emb, rotary_pos_emb),
-                    dim=0,
-                )
-            else:
-                # [TODO] (yeyu): there will be problem here with MLA
-                eagle_inputs["rotary_pos_emb"] = None
-
-        eagle_inputs["embedding"] = self.embedding(
-            input_ids=eagle_inputs["input_ids"],
-            position_ids=eagle_inputs["position_ids"],
-        )
-
-        return eagle_inputs
-
-    def forward(
-        self,
-        input_ids: torch.Tensor = None,
-        position_ids: torch.Tensor = None,
-        attention_mask: torch.Tensor = None,
-        decoder_input: torch.Tensor = None,
-        labels: torch.Tensor = None,
-        inference_params: InferenceParams = None,
-        packed_seq_params: PackedSeqParams = None,
-        extra_block_kwargs: dict | None = None,
-        return_eagle_inputs: bool = False,  # Not used in Detached Eagle
-        **kwargs,
-    ) -> torch.Tensor:
-        assert "aux_hidden_states" in kwargs, (
-            "aux_hidden_states is required as input to _DetachedEagleGPTModel"
-        )
-        assert "hidden_states" in kwargs, (
-            "hidden_states is required as input to _DetachedEagleGPTModel"
-        )
-        aux_hidden_states = kwargs.get("aux_hidden_states")
-        hidden_states = kwargs.get("hidden_states")
-
-        # Note: labels is 1 token shorter than logits in detached mode
-
-        if position_ids is None or attention_mask is None:
-            attention_mask, position_ids = get_default_attention_mask_and_position_ids(input_ids)
-
-        eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(aux_hidden_states)
-        output_weight = None
-        if self.share_embeddings_and_output_weights:
-            output_weight = self.shared_embedding_or_output_weight()
-        logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight)
-
-        eagle_inputs_0 = self._get_detached_eagle_module_inputs(
-            input_ids=input_ids,
-            hidden_states=eagle_module_input_hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-        )
-
-        _, eagle_logits_0, eagle_hidden_states_0_pre_norm = self._eagle_forward(
-            eagle_inputs_0,
-            None,
-            inference_params=inference_params,
-            packed_seq_params=packed_seq_params,
-            **(extra_block_kwargs or {}),
-        )
-
-        loss = torch.zeros(input_ids.shape).to(input_ids.device)
-
-        loss_0 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_0)
-        loss[:, 1:] += self.eagle_loss_decay_factor * loss_0
-
-        if self.eagle_report_acc and not self.training:
-            acc = []
-            with torch.no_grad():
-                gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_0[:-2, :, :]
-                )
-                eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                    eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                top1_p = torch.eq(labels[:, 1:], eagle_top1).sum() / eagle_top1.numel()
-                acc.append(top1_p)
-
-            if get_tensor_model_parallel_rank() == 0:
-                print(
-                    f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 1st Top-1: {acc}",
-                    flush=True,
-                )
-
-        # Second round of EAGLE loss
-        eagle_inputs_1 = self._get_detached_eagle_module_inputs(
-            input_ids=input_ids,
-            hidden_states=eagle_module_input_hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            features=eagle_hidden_states_0_pre_norm,
-        )
-
-        _, eagle_logits_2x, eagle_hidden_states_2x_pre_norm = self._eagle_forward(
-            eagle_inputs_1,
-            None,
-            inference_params=inference_params,
-            packed_seq_params=packed_seq_params,
-            **(extra_block_kwargs or {}),
-        )
-        eagle_logits_1 = eagle_logits_2x[logits_sbh.shape[0] :, :, :]
-
-        loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1)
-        # [b, s - 2]
-        loss_1 = loss_1[:, 1:]
-        loss[:, 2:] += self.eagle_loss_decay_factor**2 * loss_1
-
-        if self.eagle_report_acc and not self.training:
-            acc = []
-            with torch.no_grad():
-                gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_1[1:-2, :, :]
-                )
-                eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                    eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                top1_p = torch.eq(labels[:, 2:], eagle_top1).sum() / eagle_top1.numel()
-                acc.append(top1_p)
-
-            if get_tensor_model_parallel_rank() == 0:
-                print(
-                    f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 2nd Top-1: {acc}",
-                    flush=True,
-                )
-
-        # Third EAGLE loss
-        eagle_inputs_2 = self._get_detached_eagle_module_inputs(
-            input_ids=input_ids,
-            hidden_states=eagle_module_input_hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            features=eagle_hidden_states_2x_pre_norm,
-        )
-
-        _, eagle_logits_3x, eagle_hidden_states_3x_pre_norm = self._eagle_forward(
-            eagle_inputs_2,
-            None,
-            inference_params=inference_params,
-            packed_seq_params=packed_seq_params,
-            **(extra_block_kwargs or {}),
-        )
-
-        eagle_logits_2 = eagle_logits_3x[-logits_sbh.shape[0] :, :, :]
-
-        loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2)
-        # [b, s - 3]
-        loss_2 = loss_2[:, 2:]
-        loss[:, 3:] += self.eagle_loss_decay_factor**3 * loss_2
-
-        if self.eagle_report_acc and not self.training:
-            acc = []
-            with torch.no_grad():
-                gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_2[2:-2, :, :]
-                )
-                eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                    eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                top1_p = torch.eq(labels[:, 3:], eagle_top1).sum() / eagle_top1.numel()
-                acc.append(top1_p)
-
-            if get_tensor_model_parallel_rank() == 0:
-                print(
-                    f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 3rd Top-1: {acc}",
-                    flush=True,
-                )
-
-        # Forth EAGLE loss
-        eagle_inputs_3 = self._get_detached_eagle_module_inputs(
-            input_ids=input_ids,
-            hidden_states=eagle_module_input_hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            features=eagle_hidden_states_3x_pre_norm,
-        )
-
-        _, eagle_logits_4x, eagle_hidden_states_4x_pre_norm = self._eagle_forward(
-            eagle_inputs_3,
-            None,
-            inference_params=inference_params,
-            packed_seq_params=packed_seq_params,
-            **(extra_block_kwargs or {}),
-        )
-
-        eagle_logits_3 = eagle_logits_4x[-logits_sbh.shape[0] :, :, :]
-
-        loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3)
-        # [b, s - 4]
-        loss_3 = loss_3[:, 3:]
-        loss[:, 4:] += self.eagle_loss_decay_factor**4 * loss_3
-
-        if self.eagle_report_acc and not self.training:
-            acc = []
-            with torch.no_grad():
-                gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits_3[3:-2, :, :]
-                )
-                eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                    eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                top1_p = torch.eq(labels[:, 4:], eagle_top1).sum() / eagle_top1.numel()
-                acc.append(top1_p)
-
-            if get_tensor_model_parallel_rank() == 0:
-                print(
-                    f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 4th Top-1: {acc}",
-                    flush=True,
-                )
-
-        return loss
-
-
 class MegatronARValidation(AcceptanceRateValidation):
     """This is the subclass for megatron model AR validation."""
 
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index e1da326d..27f2011e 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -46,7 +46,7 @@ from transformers.trainer_pt_utils import LabelSmoother
 from transformers.utils import ModelOutput
 
-from ..eagle.conversion import EagleDMRegistry, OfflineEagleDMRegistry
+from ..eagle.conversion import EagleDMRegistry
 from ..eagle.eagle_model import EagleModel
 from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask
 from ..medusa.conversion import MedusaDMRegistry
@@ -1141,13 +1141,6 @@ def pseudo_speculative_generate(
     return base_token, draft_tokens
 
 
-@OfflineEagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
-class DetachedHFEagleModel(HFEagleModel):
-    """A wrapper for detached Eagle module."""
-
-    # TODO: Implement DetachedHFEagleModel class for offline eagle.
-
-
 class HFARValidation(AcceptanceRateValidation):
     """This is the subclass for HF model AR validation."""
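
Reviewer note (editor's addition, not part of the patch): with the detached classes removed, offline EAGLE now flows through the single _DynamicEagleGPTModel.forward, which takes the precomputed base-model tensors as keyword arguments. Below is a minimal sketch of the expected calling convention under this change; `model`, the concrete shapes, and the random tensors are illustrative assumptions, not code from this patch.

    import torch

    # Assumes `model` is a Megatron GPTModel already converted via
    # convert_to_eagle_model() with config.eagle_offline=True.
    # Megatron tensors are sequence-first: [s, b, h]; k is the number of
    # aux hidden states (EAGLE-3), so eagle_module.fc() projects k*h -> h.
    s, b, h, k, vocab = 128, 2, 4096, 3, 32000

    input_ids = torch.randint(0, vocab, (b, s))
    labels = input_ids[:, 1:]  # 1 token shorter; forward() right-pads a 0

    # Dumped offline by the base model; _base_model_forward is skipped.
    hidden_states = torch.randn(s, b, h)
    aux_hidden_states = torch.randn(s, b, k * h)

    # position_ids and attention_mask default to
    # get_default_attention_mask_and_position_ids(input_ids).
    loss = model(
        input_ids=input_ids,
        labels=labels,
        aux_hidden_states=aux_hidden_states,
        hidden_states=hidden_states,
    )

Omitting either tensor kwarg raises the new ValueError, and return_eagle_inputs=True is rejected in offline mode, since there is no base-model forward pass to collect inputs from.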