|
8 | 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
9 | 9 | # See the License for the specific language governing permissions and |
10 | 10 | # limitations under the License. |
11 | | -from unittest.mock import Mock |
| 11 | +from unittest.mock import Mock, patch |
12 | 12 |
|
13 | 13 | import pytest |
14 | 14 | from presidio_analyzer import EntityRecognizer, RecognizerResult |
|
17 | 17 | from metadata.pii.algorithms.presidio_utils import ( |
18 | 18 | apply_confidence_threshold, |
19 | 19 | build_analyzer_engine, |
| 20 | + load_nlp_engine, |
20 | 21 | set_presidio_logger_level, |
21 | 22 | ) |
22 | 23 | from metadata.pii.algorithms.tags import PIITag |
@@ -138,3 +139,88 @@ def test_threshold_of_zero_returns_all_results(self, mock_recognizer): |
138 | 139 | ) |
139 | 140 |
|
140 | 141 | assert len(results) == 3 |
| 142 | + |
| 143 | + |
@patch("metadata.pii.algorithms.presidio_utils._load_spacy_model")
@patch("metadata.pii.algorithms.presidio_utils.SpacyNlpEngine")
class TestLoadNlpEngine:
    """Checks how ``load_nlp_engine`` reuses engines across calls.

    Both class-level patches are applied to every ``test_*`` method. The
    mocks arrive bottom decorator first: the ``SpacyNlpEngine`` mock, then
    the ``_load_spacy_model`` mock.
    """

    @staticmethod
    def setup_method():
        """Start each test with an empty ``load_nlp_engine`` cache."""
        load_nlp_engine.cache_clear()

    @staticmethod
    def teardown_method():
        """Empty the ``load_nlp_engine`` cache once the test is done."""
        load_nlp_engine.cache_clear()

    def test_returns_same_instance_for_same_parameters(
        self, mock_spacy_engine_class, mock_load_spacy
    ):
        """Identical arguments must give back the very same engine object."""
        mock_spacy_engine_class.return_value = Mock()
        params = {"model_name": "en_core_web_sm", "supported_language": "en"}

        first = load_nlp_engine(**params)
        second = load_nlp_engine(**params)

        assert first is second
        # Only the first call should have built anything.
        assert mock_spacy_engine_class.call_count == 1
        assert mock_load_spacy.call_count == 1

    def test_returns_different_instances_for_different_model_names(
        self, mock_spacy_engine_class, mock_load_spacy
    ):
        """A different model name must produce a separate engine."""
        # One distinct mock per expected construction.
        mock_spacy_engine_class.side_effect = [Mock(), Mock()]

        small = load_nlp_engine(model_name="en_core_web_sm", supported_language="en")
        medium = load_nlp_engine(model_name="en_core_web_md", supported_language="en")

        assert small is not medium
        assert mock_spacy_engine_class.call_count == 2
        assert mock_load_spacy.call_count == 2

    def test_returns_different_instances_for_different_languages(
        self, mock_spacy_engine_class, mock_load_spacy
    ):
        """A different supported language must produce a separate engine."""
        # One distinct mock per expected construction.
        mock_spacy_engine_class.side_effect = [Mock(), Mock()]

        english = load_nlp_engine(model_name="en_core_web_sm", supported_language="en")
        french = load_nlp_engine(model_name="en_core_web_sm", supported_language="fr")

        assert english is not french
        assert mock_spacy_engine_class.call_count == 2

    def test_cache_persists_across_multiple_calls(
        self, mock_spacy_engine_class, mock_load_spacy
    ):
        """Three identical calls in a row still share a single engine."""
        mock_spacy_engine_class.return_value = Mock()
        params = {"model_name": "en_core_web_sm", "supported_language": "en"}

        first = load_nlp_engine(**params)
        second = load_nlp_engine(**params)
        third = load_nlp_engine(**params)

        assert first is second is third
        assert mock_spacy_engine_class.call_count == 1
        assert mock_load_spacy.call_count == 1

    def test_uses_default_parameters_when_not_provided(
        self, mock_spacy_engine_class, mock_load_spacy
    ):
        """Calls without arguments are cached like any other call."""
        mock_spacy_engine_class.return_value = Mock()

        first = load_nlp_engine()
        second = load_nlp_engine()

        assert first is second
        assert mock_spacy_engine_class.call_count == 1
0 commit comments