Skip to content

Commit 61ea727

Browse files
Copilotthinkall
andcommitted
Address code review feedback - fix type hints and simplify test logic
Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com>
1 parent bdf6c53 commit 61ea727

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

flaml/automl/automl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ def score(
749749

750750
def predict(
751751
self,
752-
X: np.array | DataFrame | list[str] | list[list[str]] | psDataFrame,
752+
X: np.ndarray | DataFrame | list[str] | list[list[str]] | psDataFrame,
753753
**pred_kwargs,
754754
):
755755
"""Predict label from features.
@@ -817,7 +817,7 @@ def predict_proba(self, X, **pred_kwargs):
817817

818818
def preprocess(
819819
self,
820-
X: np.array | DataFrame | list[str] | list[list[str]] | psDataFrame,
820+
X: np.ndarray | DataFrame | list[str] | list[list[str]] | psDataFrame,
821821
):
822822
"""Preprocess data using task-level preprocessing.
823823

test/automl/test_preprocess_api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,10 @@ def test_automl_preprocess_with_dataframe(self):
101101
# Test preprocessing
102102
X_preprocessed = automl.preprocess(X_test)
103103

104-
# Verify the output
104+
# Verify the output - check the number of rows matches
105105
self.assertIsNotNone(X_preprocessed)
106-
# The preprocessed data should have the same number of rows
107-
self.assertEqual(len(X_preprocessed) if hasattr(X_preprocessed, '__len__') else X_preprocessed.shape[0],
108-
len(X_test))
106+
preprocessed_len = len(X_preprocessed) if hasattr(X_preprocessed, '__len__') else X_preprocessed.shape[0]
107+
self.assertEqual(preprocessed_len, len(X_test))
109108

110109
def test_estimator_preprocess(self):
111110
"""Test estimator-level preprocessing."""

0 commit comments

Comments
 (0)