Description
Hi,
I am trying to fine-tune SGPT-2.7B-weightedmean-msmarco-specb-bitfit on an unlabeled dataset using the TSDAE approach, but I get this error:
TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'
Please help. Thanks!
Stack Trace:
File ~.../sentence_transformers/losses/DenoisingAutoEncoderLoss.py:111, in DenoisingAutoEncoderLoss.forward(self, sentence_features, labels)
108 label_ids = target_features['input_ids'][:, 1:]
110 # Decode
--> 111 decoder_outputs = self.decoder(
112 input_ids=decoder_input_ids,
113 inputs_embeds=None,
114 attention_mask=None,
115 encoder_hidden_states=reps[:, None], # (bsz, hdim) -> (bsz, 1, hdim)
116 encoder_attention_mask=source_features['attention_mask'][:, 0:1],
117 labels=None,
118 return_dict=None,
119 use_cache=False
120 )
122 # Calculate loss
123 lm_logits = decoder_outputs[0]
File .../dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'
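
For anyone trying to narrow this down: the traceback shows DenoisingAutoEncoderLoss passing encoder_hidden_states to the decoder, and the TypeError says the decoder's forward() has no such parameter. Below is a minimal sketch of how one might check this before training. The two decoder classes and the helper names (accepts_kwarg, try_decode) are hypothetical stand-ins for illustration only, not part of sentence_transformers or of the SGPT model.

```python
import inspect


def accepts_kwarg(func, name):
    """Return True if `func` can be called with the keyword argument `name`."""
    params = inspect.signature(func).parameters
    if name in params and params[name].kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all accepts any keyword as well.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class DecoderWithoutCrossAttention:
    """Hypothetical stand-in: forward() has no encoder_hidden_states parameter."""

    def forward(self, input_ids=None, attention_mask=None):
        return input_ids


class DecoderWithCrossAttention:
    """Hypothetical stand-in: forward() accepts the encoder arguments."""

    def forward(self, input_ids=None, attention_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None):
        return input_ids


def try_decode(decoder):
    """Call forward() the way the loss does and report what happens."""
    try:
        decoder.forward(input_ids=[1, 2, 3], encoder_hidden_states=[[0.0]])
        return "ok"
    except TypeError as exc:
        return str(exc)


if __name__ == "__main__":
    for dec in (DecoderWithCrossAttention(), DecoderWithoutCrossAttention()):
        supported = accepts_kwarg(dec.forward, "encoder_hidden_states")
        print(type(dec).__name__, supported, try_decode(dec))
```

With a real model, the same check would be `accepts_kwarg(model.forward, "encoder_hidden_states")`, assuming `model` is the module that ends up as `self.decoder` in the loss. If it returns False, that decoder cannot be called the way line 111 of the traceback calls it.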