Skip to content

Commit 9da614d

Browse files
authored
[evaluation] use thread lock for nltk data download to avoid race condition (#37487)
* use thread lock for nltk data download
* update
1 parent 8033646 commit 9da614d

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/utils.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313

1414
from typing import List
1515

16+
import threading
1617
import numpy as np
18+
import nltk
19+
20+
_nltk_data_download_lock = threading.Lock()
1721

1822

1923
def get_harm_severity_level(harm_score: int) -> str:
@@ -38,21 +42,24 @@ def get_harm_severity_level(harm_score: int) -> str:
3842
return np.nan
3943

4044

41-
def nltk_tokenize(text: str) -> List[str]:
42-
"""Tokenize the input text using the NLTK tokenizer."""
45+
def ensure_nltk_data_downloaded():
    """Download the NLTK data packages used for tokenization if any are missing.

    Each required package (``perluniprops``, ``punkt``, ``punkt_tab``) is
    looked up individually with ``nltk.data.find`` and downloaded only when
    the lookup raises ``LookupError``. Probing solely via the
    ``NISTTokenizer`` import would miss the case where ``perluniprops`` is
    installed but the ``punkt`` data is not.

    The check-and-download sequence runs under ``_nltk_data_download_lock``
    so concurrent callers do not download the same package at the same time.
    Once every package has been found or downloaded successfully, the result
    is remembered on the function object and later calls return immediately
    without taking the lock.
    """
    # Fast path: a previous call already verified every package.
    if getattr(ensure_nltk_data_downloaded, "_verified", False):
        return

    # (download package id, resource path searched under the NLTK data dirs)
    required_packages = (
        ("perluniprops", "misc/perluniprops"),
        ("punkt", "tokenizers/punkt"),
        ("punkt_tab", "tokenizers/punkt_tab"),
    )

    with _nltk_data_download_lock:
        # Re-check under the lock: another thread may have finished the
        # verification while this one was waiting.
        if getattr(ensure_nltk_data_downloaded, "_verified", False):
            return

        all_available = True
        for package_id, resource_path in required_packages:
            try:
                nltk.data.find(resource_path)
            except LookupError:
                # nltk.download reports failure through its return value
                # rather than raising; keep going so the remaining packages
                # are still attempted, but do not mark the data as verified
                # so the next call retries.
                if not nltk.download(package_id):
                    all_available = False

        ensure_nltk_data_downloaded._verified = all_available
4354

44-
import nltk
4555

46-
try:
47-
from nltk.tokenize.nist import NISTTokenizer
48-
except LookupError:
49-
nltk.download("perluniprops")
50-
nltk.download("punkt")
51-
nltk.download("punkt_tab")
52-
from nltk.tokenize.nist import NISTTokenizer
56+
def nltk_tokenize(text: str) -> List[str]:
57+
"""Tokenize the input text using the NLTK tokenizer."""
58+
ensure_nltk_data_downloaded()
5359

5460
if not text.isascii():
5561
# Use NISTTokenizer for international tokenization
62+
from nltk.tokenize.nist import NISTTokenizer
5663
tokens = NISTTokenizer().international_tokenize(text)
5764
else:
5865
# By default, use NLTK word tokenizer

sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_builtin_evaluators.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@
3333
@pytest.mark.usefixtures("recording_injection", "recorded_test")
3434
@pytest.mark.localtest
3535
class TestBuiltInEvaluators:
36-
@pytest.mark.skipif(
37-
condition=platform.python_implementation() == "PyPy",
38-
reason="Temporary skip to merge 37201, will re-enable in subsequent pr",
39-
)
4036
def test_math_evaluator_bleu_score(self):
4137
eval_fn = BleuScoreEvaluator()
4238
score = eval_fn(

0 commit comments

Comments
 (0)