From 7018bca40d512fe8b7410d2edfe6755580dcaa9f Mon Sep 17 00:00:00 2001 From: Dom <97384583+tosemml@users.noreply.github.com> Date: Sun, 3 Sep 2023 15:00:36 -0700 Subject: [PATCH] Code refactoring --- tests/unit/loader/test_tf_dataloader.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/unit/loader/test_tf_dataloader.py b/tests/unit/loader/test_tf_dataloader.py index b51d068190..6d4e6a0c28 100644 --- a/tests/unit/loader/test_tf_dataloader.py +++ b/tests/unit/loader/test_tf_dataloader.py @@ -382,11 +382,9 @@ def test_mh_support(tmpdir, batch_size): array, offsets = X[f"{mh_name}__values"], X[f"{mh_name}__offsets"] offsets = offsets.numpy() array = array.numpy() - lens = [0] - cur = 0 - for x in multihot_data[mh_name][idx * batch_size : idx * batch_size + n_samples]: - cur += len(x) - lens.append(cur) + m_dta = [len(x) for x in multihot_data[mh_name][idx * batch_size : idx * batch_size + n_samples]] + lens = [0] + np.cumsum(m_dta).tolist() + cur = np.sum(m_dta) assert (offsets == np.array(lens)).all() assert len(array) == max(lens)