Skip to content

Commit 143f41e

Browse files
committed
update data generation tests
Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
1 parent 19dd7cf commit 143f41e

File tree

5 files changed

39 additions (+39) and 124 deletions (-124) lines changed

5 files changed

39 additions (+39) and 124 deletions (-124) lines changed

.pylintrc

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -448,7 +448,9 @@ disable=raw-checker-failed,
448448
abstract-method,
449449
wrong-import-order,
450450
line-too-long,
451-
logging-fstring-interpolation
451+
logging-fstring-interpolation,
452+
# This is being set off by our deprecation warnings
453+
duplicate-code
452454

453455
# Enable the message, report, category or checker with the given id(s). You can
454456
# either give multiple identifier separated by comma (,) or put this option

src/instructlab/sdg/datamixing.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -192,9 +192,9 @@ def _create_mixed_dataset(self, num_proc):
192192
)
193193

194194
# assert that the dataset only has the allowed columns
195-
assert set(mixed_ds.column_names) == set(ALLOWED_COLS), (
196-
"Dataset has invalid columns"
197-
)
195+
assert set(mixed_ds.column_names) == set(
196+
ALLOWED_COLS
197+
), "Dataset has invalid columns"
198198
return mixed_ds
199199

200200
def add_dataset(self, path, sampling_size):

src/instructlab/sdg/mix_data.py

Lines changed: 0 additions & 88 deletions
This file was deleted.

tests/test_datamixing.py

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,10 @@
1717
DataMixer,
1818
Recipe,
1919
_add_extra_contexts_to_samples,
20+
_conv_pretrain,
21+
_create_auxiliary_dataset,
2022
_create_phase07_ds,
2123
_create_phase10_ds,
22-
_create_auxiliary_dataset,
23-
_conv_pretrain,
2424
)
2525

2626
# We mock out the actual things that use num_procs anyway, but just
@@ -269,17 +269,17 @@ def test_phase07_creation(mock_auxiliary_dataset):
269269

270270
# Check if Phase 0.7 contains knowledge and auxiliary datasets
271271
expected_phase07_size = len(knowledge_dataset) + len(auxiliary_dataset)
272-
assert len(phase07_ds) == expected_phase07_size, (
273-
"Phase 0.7 should contain knowledge and auxiliary datasets."
274-
)
272+
assert (
273+
len(phase07_ds) == expected_phase07_size
274+
), "Phase 0.7 should contain knowledge and auxiliary datasets."
275275

276276
# Verify that the content from all datasets is present in Phase 0.7
277277
auxiliary_ids = {item["id"] for item in auxiliary_dataset}
278278
phase07_ids = {item["id"] for item in phase07_ds}
279279

280-
assert auxiliary_ids.issubset(phase07_ids), (
281-
"Phase 0.7 should include all auxiliary dataset entries."
282-
)
280+
assert auxiliary_ids.issubset(
281+
phase07_ids
282+
), "Phase 0.7 should include all auxiliary dataset entries."
283283

284284

285285
@patch("instructlab.sdg.datamixing._create_auxiliary_dataset")
@@ -307,9 +307,9 @@ def test_phase10_creation(mock_auxiliary_dataset):
307307
)
308308

309309
# Check if Phase 1.0 includes knowledge, auxiliary, and knowledge_skills content
310-
assert len(phase10_ds) == phase10_expected_size, (
311-
"Phase 1.0 should contain the expected number of entries, including Phase 0.7 content."
312-
)
310+
assert (
311+
len(phase10_ds) == phase10_expected_size
312+
), "Phase 1.0 should contain the expected number of entries, including Phase 0.7 content."
313313

314314

315315
def test_all_samples_have_unmask_field():
@@ -375,11 +375,11 @@ def test_phase07_knowledge_samples_have_unmask_true():
375375
lambda rec: _conv_pretrain(rec, use_legacy_pretraining_format=False)
376376
)
377377
for sample in auxiliary_ds:
378-
assert sample["unmask"] is True, (
379-
"Auxiliary sample does not have unmask=True"
380-
)
378+
assert (
379+
sample["unmask"] is True
380+
), "Auxiliary sample does not have unmask=True"
381381

382382
# verify that at least ONE sample in phase10 has unmask=True
383-
assert any(sample["unmask"] for sample in phase10_ds), (
384-
"No samples in phase10 have unmask=True"
385-
)
383+
assert any(
384+
sample["unmask"] for sample in phase10_ds
385+
), "No samples in phase10 have unmask=True"

tests/test_generate_data.py

Lines changed: 16 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -93,20 +93,21 @@ def validate_messages_dataset(dataset_file_name, expected_samples):
9393

9494
def validate_skill_leaf_node_dataset(dataset_file_name):
9595
ds = load_dataset("json", data_files=dataset_file_name, split="train")
96-
assert len(ds.features) == 9
96+
assert len(ds.features) == 10
9797
features = [
98-
"task_description",
99-
"seed_context",
100-
"seed_question",
101-
"seed_response",
102-
"output",
103-
"id",
104-
"leaf_node_path",
105-
"leaf_node_type",
98+
("task_description", "string"),
99+
("seed_context", "string"),
100+
("seed_question", "string"),
101+
("seed_response", "string"),
102+
("output", "string"),
103+
("id", "string"),
104+
("leaf_node_path", "string"),
105+
("leaf_node_type", "string"),
106+
("unmask", "bool"),
106107
]
107-
for feature in features:
108+
for feature, dtype in features:
108109
assert feature in ds.features
109-
assert ds.features[feature].dtype == "string"
110+
assert ds.features[feature].dtype == dtype
110111
assert "messages" in ds.features
111112
assert len(ds.features["messages"]) == 1
112113
assert len(ds.features["messages"][0]) == 2
@@ -116,11 +117,11 @@ def validate_skill_leaf_node_dataset(dataset_file_name):
116117

117118
def validate_phase_leaf_node_dataset(dataset_file_name):
118119
ds = load_dataset("json", data_files=dataset_file_name, split="train")
119-
assert len(ds.features) == 3
120-
features = ["metadata", "id"]
121-
for feature in features:
120+
assert len(ds.features) == 4
121+
features = [("metadata", "string"), ("id", "string"), ("unmask", "bool")]
122+
for feature, dtype in features:
122123
assert feature in ds.features
123-
assert ds.features[feature].dtype == "string"
124+
assert ds.features[feature].dtype == dtype
124125
assert "messages" in ds.features
125126
assert len(ds.features["messages"]) == 1
126127
assert len(ds.features["messages"][0]) == 2

0 commit comments

Comments
 (0)