Skip to content

Commit f4a9e37

Browse files
committed
Fix isinstance usage issues
1 parent e64b486 commit f4a9e37

File tree

7 files changed

+18
-22
lines changed

7 files changed

+18
-22
lines changed

flaml/automl/ml.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,14 +311,14 @@ def get_y_pred(estimator, X, eval_metric, task: Task):
311311
else:
312312
y_pred = estimator.predict(X)
313313

314-
if isinstance(y_pred, Series) or isinstance(y_pred, DataFrame):
314+
if isinstance(y_pred, (Series, DataFrame)):
315315
y_pred = y_pred.values
316316

317317
return y_pred
318318

319319

320320
def to_numpy(x):
321-
if isinstance(x, Series or isinstance(x, DataFrame)):
321+
if isinstance(x, (Series, DataFrame)):
322322
x = x.values
323323
else:
324324
x = np.ndarray(x)
@@ -586,7 +586,7 @@ def _eval_estimator(
586586

587587
# TODO: why are integer labels being cast to str in the first place?
588588

589-
if isinstance(val_pred_y, Series) or isinstance(val_pred_y, DataFrame) or isinstance(val_pred_y, np.ndarray):
589+
if isinstance(val_pred_y, (Series, DataFrame, np.ndarray)):
590590
test = val_pred_y if isinstance(val_pred_y, np.ndarray) else val_pred_y.values
591591
if not np.issubdtype(test.dtype, np.number):
592592
# some NLP models return a list

flaml/automl/nlp/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ def load_default_huggingface_metric_for_task(task):
2525

2626

2727
def is_a_list_of_str(this_obj):
28-
return (isinstance(this_obj, list) or isinstance(this_obj, np.ndarray)) and all(
29-
isinstance(x, str) for x in this_obj
30-
)
28+
return isinstance(this_obj, (list, np.ndarray)) and all(isinstance(x, str) for x in this_obj)
3129

3230

3331
def _clean_value(value: Any) -> str:

flaml/automl/state.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@ def valid_starting_point_one_dim(self, value_one_dim, domain_one_dim):
3737
if isinstance(domain_one_dim, sample.Domain):
3838
renamed_type = list(inspect.signature(domain_one_dim.is_valid).parameters.values())[0].annotation
3939
type_match = (
40-
renamed_type == Any
40+
renamed_type is Any
4141
or isinstance(value_one_dim, renamed_type)
42-
or isinstance(value_one_dim, int)
43-
and renamed_type is float
42+
or (renamed_type is float and isinstance(value_one_dim, int))
4443
)
4544
if not (type_match and domain_one_dim.is_valid(value_one_dim)):
4645
return False

flaml/automl/task/time_series_task.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,8 @@ def _preprocess(self, X, transformer=None):
386386
return X
387387

388388
def preprocess(self, X, transformer=None):
389-
if isinstance(X, pd.DataFrame) or isinstance(X, np.ndarray) or isinstance(X, pd.Series):
390-
X = X.copy()
391-
X = normalize_ts_data(X, self.target_names, self.time_col)
389+
if isinstance(X, (pd.DataFrame, np.ndarray, pd.Series)):
390+
X = normalize_ts_data(X.copy(), self.target_names, self.time_col)
392391
return self._preprocess(X, transformer)
393392
elif isinstance(X, int):
394393
return X

flaml/automl/time_series/ts_data.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -546,14 +546,12 @@ def normalize_ts_data(X_train_all, target_names, time_col, y_train_all=None):
546546

547547

548548
def validate_data_basic(X_train_all, y_train_all):
549-
assert isinstance(X_train_all, np.ndarray) or issparse(X_train_all) or isinstance(X_train_all, pd.DataFrame), (
550-
"X_train_all must be a numpy array, a pandas dataframe, " "or Scipy sparse matrix."
551-
)
549+
assert isinstance(X_train_all, (np.ndarray, DataFrame)) or issparse(
550+
X_train_all
551+
), "X_train_all must be a numpy array, a pandas dataframe, or Scipy sparse matrix."
552552

553-
assert (
554-
isinstance(y_train_all, np.ndarray)
555-
or isinstance(y_train_all, pd.Series)
556-
or isinstance(y_train_all, pd.DataFrame)
553+
assert isinstance(
554+
y_train_all, (np.ndarray, pd.Series, pd.DataFrame)
557555
), "y_train_all must be a numpy array or a pandas series or DataFrame."
558556

559557
assert X_train_all.size != 0 and y_train_all.size != 0, "Input data must not be empty, use None if no data"

flaml/tune/searcher/flow2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,10 @@ def config_signature(self, config, space: Dict = None) -> tuple:
641641
else:
642642
# key must be in space
643643
domain = space[key]
644-
if self.hierarchical and not (
645-
domain is None or type(domain) in (str, int, float) or isinstance(domain, sample.Domain)
644+
if (
645+
self.hierarchical
646+
and domain is not None
647+
and not isinstance(domain, (str, int, float, sample.Domain))
646648
):
647649
# not domain or hashable
648650
# get rid of list type for hierarchical search space.

flaml/tune/searcher/online_searcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _query_config_oracle(
207207
hyperparameter_config_groups.append(partial_new_configs)
208208
# does not have searcher_trial_ids
209209
searcher_trial_ids_groups.append([])
210-
elif isinstance(config_domain, Float) or isinstance(config_domain, Categorical):
210+
elif isinstance(config_domain, (Float, Categorical)):
211211
# otherwise we need to deal with them in group
212212
nonpoly_config[k] = v
213213
if k not in self._space_of_nonpoly_hp:

0 commit comments

Comments
 (0)