Skip to content

Commit 7b501a9

Browse files
authored
Update classifier and semantic deduplication PyTests (#697)
* update semdedup checks to be more robust Signed-off-by: Sarah Yurick <[email protected]> * fix ruff Signed-off-by: Sarah Yurick <[email protected]> * fix tolist Signed-off-by: Sarah Yurick <[email protected]> * relax rounding for prompt classifier test Signed-off-by: Sarah Yurick <[email protected]> --------- Signed-off-by: Sarah Yurick <[email protected]>
1 parent 65e2fbc commit 7b501a9

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

tests/test_classifiers.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -273,12 +273,16 @@ def test_prompt_task_complexity_classifier(gpu_client) -> None: # noqa: ANN001,
273273
# Rounded values to account for floating point errors
274274
result_pred["constraint_ct"] = round(result_pred["constraint_ct"], 2)
275275
expected_pred["constraint_ct"] = round(expected_pred["constraint_ct"], 2)
276-
result_pred["contextual_knowledge"] = round(result_pred["contextual_knowledge"], 3)
277-
expected_pred["contextual_knowledge"] = round(expected_pred["contextual_knowledge"], 3)
276+
result_pred["contextual_knowledge"] = round(result_pred["contextual_knowledge"], 2)
277+
expected_pred["contextual_knowledge"] = round(expected_pred["contextual_knowledge"], 2)
278278
result_pred["creativity_scope"] = round(result_pred["creativity_scope"], 2)
279279
expected_pred["creativity_scope"] = round(expected_pred["creativity_scope"], 2)
280-
result_pred["prompt_complexity_score"] = round(result_pred["prompt_complexity_score"], 3)
281-
expected_pred["prompt_complexity_score"] = round(expected_pred["prompt_complexity_score"], 3)
280+
result_pred["domain_knowledge"] = round(result_pred["domain_knowledge"], 2)
281+
expected_pred["domain_knowledge"] = round(expected_pred["domain_knowledge"], 2)
282+
result_pred["prompt_complexity_score"] = round(result_pred["prompt_complexity_score"], 2)
283+
expected_pred["prompt_complexity_score"] = round(expected_pred["prompt_complexity_score"], 2)
284+
result_pred["reasoning"] = round(result_pred["reasoning"], 2)
285+
expected_pred["reasoning"] = round(expected_pred["reasoning"], 2)
282286
result_pred["task_type_prob"] = round(result_pred["task_type_prob"], 2)
283287
expected_pred["task_type_prob"] = round(expected_pred["task_type_prob"], 2)
284288

tests/test_semdedup.py

Lines changed: 28 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
import os
1516
import random
17+
from itertools import product
1618
from pathlib import Path
1719
from typing import TYPE_CHECKING, Literal
1820

@@ -145,19 +147,35 @@ def test_sem_dedup(
145147
# Correctly returns the original dataset with no duplicates removed
146148
result = sem_duplicates(dedup_data)
147149
result_df = result.df.compute()
148-
docs_to_remove = [1, 100]
149-
if id_col_type == "str":
150-
docs_to_remove = list(map(str, docs_to_remove))
151150

152151
if not perform_removal:
153-
expected_df = cudf.Series(docs_to_remove, name="id", dtype=id_col_type)
154-
assert_eq(result_df["id"].sort_values(), expected_df, check_index=False)
152+
first_doc_to_remove = [1, 2, 3, 4]
153+
second_doc_to_remove = [100, 200, 300]
154+
# Generate all possible combinations of documents to remove
155+
expected_series_list = [
156+
cudf.Series([a, b], name="id", dtype=id_col_type).sort_values().reset_index(drop=True)
157+
for a, b in product(first_doc_to_remove, second_doc_to_remove)
158+
]
159+
160+
result_series = result_df["id"].sort_values().reset_index(drop=True)
161+
assert any(result_series.equals(expected_series) for expected_series in expected_series_list)
155162
else:
156-
assert_eq(
157-
result_df,
158-
dedup_data.df[~dedup_data.df["id"].isin(docs_to_remove)],
159-
check_index=False,
160-
)
163+
if id_col_type == "int":
164+
first_doc_to_keep = {1, 2, 3, 4}
165+
second_doc_to_keep = {100, 200, 300}
166+
else:
167+
first_doc_to_keep = {"1", "2", "3", "4"}
168+
second_doc_to_keep = {"100", "200", "300"}
169+
170+
result_ids = set(result_df["id"].to_arrow().to_pylist())
171+
172+
# Intersection of the sets
173+
num_kept_from_first = len(result_ids & first_doc_to_keep)
174+
num_kept_from_second = len(result_ids & second_doc_to_keep)
175+
176+
assert len(result_ids) == 5 # noqa: PLR2004
177+
assert num_kept_from_first == 3 # noqa: PLR2004
178+
assert num_kept_from_second == 2 # noqa: PLR2004
161179

162180
@pytest.mark.parametrize("n_clusters", [2, 3])
163181
@pytest.mark.parametrize("perform_removal", [True, False])

0 commit comments

Comments
 (0)