Skip to content

Commit 53a320e

Browse files
authored
Fix: fasttext predict call for numpy>2 (#1482)
Signed-off-by: Ayush Dattagupta <ayushdg95@gmail.com>
1 parent 90ce791 commit 53a320e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

nemo_curator/stages/text/filters/fasttext_filter.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ def score_document(self, text: str) -> float:
4444
model = self._fasttext_quality_filter_model
4545

4646
text = text.replace("\n", " ").replace("__label__", " ")
47-
pred = model.predict(text)
48-
document_score = pred[1][0]
49-
if pred[0][0] != self._label:
47+
label, score = model.predict([text])
48+
document_score = score[0][0].item()
49+
if label[0][0] != self._label:
5050
document_score = 1 - document_score
5151

5252
return document_score
@@ -78,9 +78,9 @@ def score_document(self, text: str) -> list[float | str]:
7878
model = self._fasttext_langid_model
7979

8080
pp = text.strip().replace("\n", " ")
81-
label, score = model.predict(pp, k=1)
82-
score = score[0]
83-
lang_code = label[0][-2:].upper()
81+
label, score = model.predict([pp], k=1)
82+
score = score[0][0].item()
83+
lang_code = label[0][0][-2:].upper()
8484

8585
# Need to convert it to a string to allow backend conversions
8686
return str([score, lang_code])

0 commit comments

Comments
 (0)