Skip to content

Commit 151dc0c

Browse files
committed
MAINT fix train_evaluator and test_evaluator tests
1 parent 43762a5 commit 151dc0c

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

test/test_evaluation/test_test_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_datasets(self):
4444
D_ = copy.deepcopy(D)
4545
y = D.data['Y_train']
4646
if len(y.shape) == 2 and y.shape[1] == 1:
47-
y = y.flatten()
47+
D_.data['Y_train'] = y.flatten()
4848
queue_ = multiprocessing.Queue()
4949
evaluator = TestEvaluator(D_, backend_mock, queue_)
5050

test/test_evaluation/test_train_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def test_datasets(self):
553553
D_ = copy.deepcopy(D)
554554
y = D.data['Y_train']
555555
if len(y.shape) == 2 and y.shape[1] == 1:
556-
y = y.flatten()
556+
D_.data['Y_train'] = y.flatten()
557557
kfold = ShuffleSplit(n=len(y), n_iter=5, random_state=1)
558558
queue_ = multiprocessing.Queue()
559559
evaluator = TrainEvaluator(D_, backend_mock, queue_,

0 commit comments

Comments
 (0)