Skip to content

Commit b1f2c15

Browse files
md-shafiul-alamolegkkruglov
authored andcommitted
lint
1 parent 192744f commit b1f2c15

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

sklbench/benchmarks/sklearn_estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def get_subset_metrics_of_estimator(
145145
"balanced accuracy": float(balanced_accuracy_score(y_compat, y_pred)),
146146
}
147147
)
148-
'''if hasattr(estimator_instance, "predict_proba") and not (
148+
"""if hasattr(estimator_instance, "predict_proba") and not (
149149
hasattr(estimator_instance, "probability")
150150
and getattr(estimator_instance, "probability") == False
151151
):
@@ -165,7 +165,7 @@ def get_subset_metrics_of_estimator(
165165
),
166166
"logloss": float(log_loss(y_compat, y_pred_proba)),
167167
}
168-
)'''
168+
)"""
169169
elif task == "regression":
170170
y_pred = convert_to_numpy(estimator_instance.predict(x))
171171
metrics.update(
@@ -463,7 +463,7 @@ def measure_sklearn_estimator(
463463
metrics[method]["time std[ms]"],
464464
metrics[method]["first iter[ms]"],
465465
metrics[method]["box filter mean[ms]"],
466-
metrics[method]["box filter std[ms]"]
466+
metrics[method]["box filter std[ms]"],
467467
) = measure_case(bench_case, method_instance, *data_args)
468468
if ensure_sklearnex_patching:
469469
full_method_name = f"{estimator_class.__name__}.{method}"
@@ -546,7 +546,7 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
546546
result_template = enrich_result(result_template, bench_case)
547547
if "assume_finite" in context_params:
548548
result_template["assume_finite"] = context_params["assume_finite"]
549-
#if hasattr(estimator_instance, "get_params"):
549+
# if hasattr(estimator_instance, "get_params"):
550550
# estimator_params = estimator_instance.get_params()
551551
# note: "handle" is not JSON-serializable
552552
if "handle" in estimator_params:

sklbench/datasets/transformer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,10 @@ def split_and_transform_data(bench_case, data, data_description):
109109
y_train, y_test = None, None
110110

111111
distributed_split = get_bench_case_value(bench_case, "data:distributed_split", None)
112-
knn_split_train = "KNeighbors" in get_bench_case_value(bench_case, "algorithm:estimator", "") and int(get_bench_case_value(bench_case, "bench:mpi_params:n", 1)) > 1
112+
knn_split_train = (
113+
"KNeighbors" in get_bench_case_value(bench_case, "algorithm:estimator", "")
114+
and int(get_bench_case_value(bench_case, "bench:mpi_params:n", 1)) > 1
115+
)
113116
if distributed_split == "rank_based" or knn_split_train:
114117
from mpi4py import MPI
115118

sklbench/runner/commands_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def generate_benchmark_command(
4747
mpi_prefix += f" -{mpi_param_name} {mpi_param_value}"
4848
if mpi_param_name == "-hostfile":
4949
import os
50+
5051
mpi_prefix += os.environ.get("PBS_NODEFILE")
5152
command_prefix = f"{mpi_prefix} {command_prefix}"
5253
# 3. Intel(R) VTune* profiling command prefix

sklbench/utils/measurement.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ def large_scale_measurements(timing):
4848
Q1, Q3 = np.percentile(timing_sorted, [25, 75])
4949
IQ = Q3 - Q1
5050
lower, upper = Q1 - 1.5 * IQ, Q3 + 1.5 * IQ
51-
51+
5252
filtered_times = timing_sorted[(timing_sorted >= lower) & (timing_sorted <= upper)]
53-
53+
5454
box_filter_mean = np.mean(filtered_times) * 1000 if filtered_times.size > 0 else 0
5555
box_filter_stdev = np.std(filtered_times) * 1000 if filtered_times.size > 0 else 0
5656
return mean, stdev, first_iter, box_filter_mean, box_filter_stdev
@@ -89,8 +89,8 @@ def measure_time(
8989
)
9090
break
9191
logger.debug(times)
92-
#mean, std = box_filter(times)
93-
#if std / mean > std_mean_ratio:
92+
# mean, std = box_filter(times)
93+
# if std / mean > std_mean_ratio:
9494
# logger.warning(
9595
# f'Measured "std / mean" time ratio of "{str(func)}" function is higher '
9696
# f"than threshold ({round(std / mean, 3)} vs. {std_mean_ratio})"

0 commit comments

Comments
 (0)