Skip to content

Commit ab9c7b9

Browse files
authored
fix: replace func "len()" in ensemble test code to support various data type (#739)
* replace len with def get_length * update get_length implementation
1 parent bc8e846 commit ab9c7b9

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

rdagent/components/coder/data_science/ensemble/eval_tests/ensemble_test.txt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ def print_preds_info(model_name, data_type, preds):
3434
else:
3535
print(f"Unknown prediction type: {type(preds)}")
3636

37+
def get_length(data):
38+
return data.shape[0] if hasattr(data, 'shape') else len(data)
39+
3740
X, y, test_X, test_ids = load_data()
3841
X, y, test_X = feat_eng(X, y, test_X)
3942
train_X, val_X, train_y, val_y = train_test_split(X, y, test_size=0.2, random_state=42)
@@ -109,9 +112,9 @@ assert isinstance(final_pred, pred_type), (
109112

110113
# Check shape
111114
if isinstance(final_pred, (list, np.ndarray, pd.DataFrame, torch.Tensor, tf.Tensor)):
112-
assert len(final_pred) == len(test_X), (
113-
f"Wrong output sample size: len(final_pred)={len(final_pred)} "
114-
f"vs. len(test_X)={len(test_X)}"
115+
assert get_length(final_pred) == get_length(test_X), (
116+
f"Wrong output sample size: get_length(final_pred)={get_length(final_pred)} "
117+
f"vs. get_length(test_X)={get_length(test_X)}"
115118
)
116119

117120
# check scores.csv

0 commit comments

Comments
 (0)