Skip to content

Commit d22b2db

Browse files
committed
bug fix
1 parent 7f4d7ac commit d22b2db

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,10 @@ def _refit(self, context: Context) -> None:
229229

230230
context.data_handler.prepare_for_refit()
231231

232-
scoring_module.fit(context.data_handler.train_utterances(0), context.data_handler.train_labels(0)) # type: ignore[arg-type]
232+
scoring_module.fit(*scoring_module.get_train_data(context)) # type: ignore[arg-type]
233233
scores = scoring_module.predict(context.data_handler.train_utterances(1))
234234

235-
decision_module.fit(scores, context.data_handler.train_labels(1))
235+
decision_module.fit(scores, context.data_handler.train_labels(1), context.data_handler.tags)
236236

237237
def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
238238
"""

tests/data/test_stratificaiton.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,6 @@ def test_multilabel_train_test_split(dataset_unsplitted):
3838

3939
assert Split.TRAIN in dataset
4040
assert Split.TEST in dataset
41-
assert dataset[Split.TRAIN].num_rows == 17
42-
assert dataset[Split.TEST].num_rows == 19
41+
assert dataset[Split.TRAIN].num_rows == 19
42+
assert dataset[Split.TEST].num_rows == 17
4343
assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST)

0 commit comments

Comments
 (0)