Skip to content

Commit eeb30f4

Browse files
Darinochka and truff4ut authored
feat: added eps for zeros (#35)
* feat: added eps for zeros
* Add embedding normalization

---------

Co-authored-by: Egor Sergeenko <[email protected]>
1 parent 075b8de commit eeb30f4

File tree

2 files changed: +6 -3 lines changed

autointent/context/embedder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,6 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
 71  71    )
 72  72    if self.max_length is not None:
 73  73    self.embedding_model.max_seq_length = self.max_length
 74      - return self.embedding_model.encode(utterances, convert_to_numpy=True, batch_size=self.batch_size) # type: ignore[return-value]
     74  + return self.embedding_model.encode(
     75  + utterances, convert_to_numpy=True, batch_size=self.batch_size, normalize_embeddings=True,
     76  + ) # type: ignore[return-value]

autointent/metrics/scoring.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __call__(self, labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> floa
 24  24    ...
 25  25
 26  26
 27      - def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
     27  + def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE, eps: float = 1e-10) -> float:
 28  28    """
 29  29    supports multiclass and multilabel
 30  30
@@ -45,9 +45,10 @@ def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE)
 45  45    where `s[i,c]` is a predicted score of `i`th utterance having ground truth label `c`
 46  46    """
 47  47    labels_array, scores_array = transform(labels, scores)
     48  + scores_array[scores_array == 0] = eps
 48  49
 49  50    if np.any((scores_array <= 0) | (scores_array > 1)):
 50      - msg = "One or more scores are not from [0,1]. It is incompatible with `scoring_log_likelihood` metric"
     51  + msg = "One or more scores are not from (0,1]. It is incompatible with `scoring_log_likelihood` metric"
 51  52    logger.error(msg)
 52  53    raise ValueError(msg)
 53  54
0 commit comments

Comments
 (0)