Skip to content

Commit c69d6a1

Browse files
[ENH] Collection conversion cleanup and df-list fix (#2654)
* collection conversion cleanup * notebook * fixes --------- Co-authored-by: Tony Bagnall <[email protected]>
1 parent d1d1ae5 commit c69d6a1

File tree

12 files changed

+673
-420
lines changed

12 files changed

+673
-420
lines changed

aeon/classification/early_classification/tests/test_probability_threshold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_early_prob_threshold_near_classification_points():
3232
X = X_test[:, :, :i]
3333

3434
if i == 20:
35-
with pytest.raises(ValueError):
35+
with pytest.raises(IndexError):
3636
pt.update_predict_proba(X)
3737
else:
3838
_, decisions = pt.update_predict_proba(X)

aeon/classification/early_classification/tests/test_teaser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_teaser_near_classification_points():
8080
X = X_test[:, :, :i]
8181

8282
if i == 20:
83-
with pytest.raises(ValueError):
83+
with pytest.raises(IndexError):
8484
teaser.update_predict_proba(X)
8585
else:
8686
_, decisions = teaser.update_predict(X)

aeon/testing/data_generation/_collection.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -408,15 +408,11 @@ def make_example_dataframe_list(
408408
... random_state=0,
409409
... )
410410
>>> print(data)
411-
[ 0 1
412-
0 0.000000 1.688531
413-
1 1.715891 1.694503
414-
2 1.247127 0.768763
415-
3 0.595069 0.113426, 0 1
416-
0 2.000000 3.166900
417-
1 2.115580 2.272178
418-
2 3.702387 0.284144
419-
3 0.348517 0.080874]
411+
[ 0 1 2 3
412+
0 0.000000 1.688531 1.715891 1.694503
413+
1 1.247127 0.768763 0.595069 0.113426, 0 1 2 3
414+
0 2.000000 3.166900 2.115580 2.272178
415+
1 3.702387 0.284144 0.348517 0.080874]
420416
>>> print(labels)
421417
[0 1]
422418
>>> get_type(data)
@@ -428,14 +424,14 @@ def make_example_dataframe_list(
428424

429425
for i in range(n_cases):
430426
n_timepoints = rng.randint(min_n_timepoints, max_n_timepoints + 1)
431-
x = n_labels * rng.uniform(size=(n_timepoints, n_channels))
427+
x = n_labels * rng.uniform(size=(n_channels, n_timepoints))
432428
label = x[0, 0].astype(int)
433429
if i < n_labels and n_cases > i:
434430
x[0, 0] = i
435431
label = i
436432
x = x * (label + 1)
437433

438-
X.append(pd.DataFrame(x, index=range(n_timepoints), columns=range(n_channels)))
434+
X.append(pd.DataFrame(x, index=range(n_channels), columns=range(n_timepoints)))
439435
y[i] = label
440436

441437
if regression_target:
@@ -574,16 +570,16 @@ def make_example_multi_index_dataframe(
574570
... random_state=0,
575571
... )
576572
>>> print(data) # doctest: +NORMALIZE_WHITESPACE
577-
channel_0 channel_1
573+
channel 0 1
578574
case timepoint
579-
0 0 0.000000 1.247127
580-
1 1.688531 0.768763
581-
2 1.715891 0.595069
582-
3 1.694503 0.113426
583-
1 0 2.000000 3.702387
584-
1 3.166900 0.284144
585-
2 2.115580 0.348517
586-
3 2.272178 0.080874
575+
0 0 0.000000 1.247127
576+
1 1.688531 0.768763
577+
2 1.715891 0.595069
578+
3 1.694503 0.113426
579+
1 0 2.000000 3.702387
580+
1 3.166900 0.284144
581+
2 2.115580 0.348517
582+
3 2.272178 0.080874
587583
>>> print(labels)
588584
[0 1]
589585
>>> get_type(data)
@@ -616,8 +612,7 @@ def make_example_multi_index_dataframe(
616612
y[i] = label
617613

618614
X = X.reset_index(drop=True)
619-
X = X.set_index(["case", "timepoint"]).pivot(columns="channel")
620-
X.columns = [f"channel_{i}" for i in range(n_channels)]
615+
X = X.pivot(index=["case", "timepoint"], columns=["channel"], values="value")
621616

622617
if regression_target:
623618
y = y.astype(np.float32)

aeon/testing/data_generation/tests/test_collection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ def test_make_example_dataframe_list(
178178
assert all(isinstance(x, pd.DataFrame) for x in X)
179179
assert isinstance(y, np.ndarray)
180180
assert len(X) == n_cases
181-
assert all([x.shape[1] == n_channels for x in X])
181+
assert all([x.shape[0] == n_channels for x in X])
182182
if min_n_timepoints == max_n_timepoints:
183-
assert all([x.shape[0] == min_n_timepoints for x in X])
183+
assert all([x.shape[1] == min_n_timepoints for x in X])
184184
else:
185185
assert all(
186186
[
187-
x.shape[0] >= min_n_timepoints and x.shape[0] <= max_n_timepoints
187+
x.shape[1] >= min_n_timepoints and x.shape[1] <= max_n_timepoints
188188
for x in X
189189
]
190190
)

aeon/testing/tests/test_testing_data.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,9 @@ def test_equal_length_univariate_collection():
122122
assert not is_collection(
123123
EQUAL_LENGTH_UNIVARIATE_SIMILARITY_SEARCH[key]["test"][0]
124124
)
125-
assert is_univariate(
126-
EQUAL_LENGTH_UNIVARIATE_SIMILARITY_SEARCH[key]["test"][0],
127-
is_collection=False,
128-
)
125+
# assert is_univariate(
126+
# EQUAL_LENGTH_UNIVARIATE_SIMILARITY_SEARCH[key]["test"][0],
127+
# )
129128
assert is_equal_length(
130129
EQUAL_LENGTH_UNIVARIATE_SIMILARITY_SEARCH[key]["test"][0]
131130
)
@@ -199,10 +198,9 @@ def test_unequal_length_univariate_collection():
199198
assert not is_collection(
200199
UNEQUAL_LENGTH_UNIVARIATE_SIMILARITY_SEARCH[key]["test"][0]
201200
)
202-
assert is_univariate(
203-
UNEQUAL_LENGTH_UNIVARIATE_SIMILARITY_SEARCH[key]["test"][0],
204-
is_collection=False,
205-
)
201+
# assert is_univariate(
202+
# UNEQUAL_LENGTH_UNIVARIATE_SIMILARITY_SEARCH[key]["test"][0],
203+
# )
206204
assert is_equal_length(
207205
UNEQUAL_LENGTH_UNIVARIATE_SIMILARITY_SEARCH[key]["test"][0]
208206
)
@@ -276,10 +274,9 @@ def test_equal_length_multivariate_collection():
276274
assert not is_collection(
277275
EQUAL_LENGTH_MULTIVARIATE_SIMILARITY_SEARCH[key]["test"][0]
278276
)
279-
assert not is_univariate(
280-
EQUAL_LENGTH_MULTIVARIATE_SIMILARITY_SEARCH[key]["test"][0],
281-
is_collection=False,
282-
)
277+
# assert not is_univariate(
278+
# EQUAL_LENGTH_MULTIVARIATE_SIMILARITY_SEARCH[key]["test"][0],
279+
# )
283280
assert is_equal_length(
284281
EQUAL_LENGTH_MULTIVARIATE_SIMILARITY_SEARCH[key]["test"][0]
285282
)
@@ -365,10 +362,9 @@ def test_unequal_length_multivariate_collection():
365362
assert not is_collection(
366363
UNEQUAL_LENGTH_MULTIVARIATE_SIMILARITY_SEARCH[key]["test"][0]
367364
)
368-
assert not is_univariate(
369-
UNEQUAL_LENGTH_MULTIVARIATE_SIMILARITY_SEARCH[key]["test"][0],
370-
is_collection=False,
371-
)
365+
# assert not is_univariate(
366+
# UNEQUAL_LENGTH_MULTIVARIATE_SIMILARITY_SEARCH[key]["test"][0],
367+
# )
372368
assert is_equal_length(
373369
UNEQUAL_LENGTH_MULTIVARIATE_SIMILARITY_SEARCH[key]["test"][0]
374370
)

aeon/testing/utils/deep_equals.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def _deep_equals(x, y, depth, ignore_index):
8484
def _series_equals(x, y, depth, ignore_index):
8585
if x.dtype != y.dtype:
8686
return False, f"x.dtype ({x.dtype}) != y.dtype ({y.dtype}), depth={depth}"
87+
if x.shape != y.shape:
88+
return False, f"x.shape ({x.shape}) != y.shape ({y.shape}), depth={depth}"
8789

8890
# if columns are object, recurse over entries and index
8991
if x.dtype == "object":
@@ -108,7 +110,12 @@ def _series_equals(x, y, depth, ignore_index):
108110

109111
def _dataframe_equals(x, y, depth, ignore_index):
110112
if not x.columns.equals(y.columns):
111-
return False, f"x.columns ({x.columns}) != y.columns ({y.columns})"
113+
return (
114+
False,
115+
f"x.columns ({x.columns}) != y.columns ({y.columns}), depth={depth}",
116+
)
117+
if x.shape != y.shape:
118+
return False, f"x.shape ({x.shape}) != y.shape ({y.shape}), depth={depth}"
112119

113120
# if columns are equal and at least one is object, recurse over Series
114121
if sum(x.dtypes == "object") > 0:
@@ -130,7 +137,9 @@ def _dataframe_equals(x, y, depth, ignore_index):
130137

131138
def _numpy_equals(x, y, depth, ignore_index):
132139
if x.dtype != y.dtype:
133-
return False, f"x.dtype ({x.dtype}) != y.dtype ({y.dtype})"
140+
return False, f"x.dtype ({x.dtype}) != y.dtype ({y.dtype}), depth={depth}"
141+
if x.shape != y.shape:
142+
return False, f"x.shape ({x.shape}) != y.shape ({y.shape}), depth={depth}"
134143

135144
if x.dtype == "object":
136145
for i in range(len(x)):

0 commit comments

Comments
 (0)