Skip to content

Commit f275062

Browse files
committed
Fix num_batches and batch_size reading from config
1 parent 69cc4c1 commit f275062

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

configs/sklearnex_incremental_example.json

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@
5656
}
5757
},
5858
"TEMPLATES": {
59-
"covariance": {"SETS": ["common", "covariance", "unlabeled dataset"]},
60-
"linear_regression": {
61-
"SETS": ["common", "linear_regression", "labeled dataset"]
62-
},
63-
"pca": {"SETS": ["common", "pca", "unlabeled dataset"]}
59+
"covariance": {"SETS": ["common", "covariance", "unlabeled dataset"]}
6460
}
6561
}

sklbench/benchmarks/sklearn_estimator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,12 @@ def measure_sklearn_estimator(
425425
data_args = (x_test,)
426426

427427
if method == "partial_fit":
428-
num_batches = get_bench_case_value(bench_case, "data:num_batches")
429-
batch_size = get_bench_case_value(bench_case, "data:batch_size")
428+
num_batches = get_bench_case_value(
429+
bench_case, f"algorithm:num_batches:{stage}"
430+
)
431+
batch_size = get_bench_case_value(
432+
bench_case, f"algorithm:batch_size:{stage}"
433+
)
430434

431435
if batch_size is None:
432436
if num_batches is None:

sklbench/report/implementation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def get_result_tables_as_df(
236236
splitby_columns=["estimator", "method", "function"],
237237
compatibility_mode=False,
238238
):
239+
print(results["bench_cases"])
239240
bench_cases = pd.DataFrame(
240241
[flatten_dict(bench_case) for bench_case in results["bench_cases"]]
241242
)
@@ -244,6 +245,7 @@ def get_result_tables_as_df(
244245
if compatibility_mode:
245246
bench_cases = transform_results_to_compatible(bench_cases)
246247

248+
print(bench_cases)
247249
for column in diffby_columns.copy():
248250
if bench_cases[column].nunique() == 1:
249251
bench_cases.drop(columns=[column], inplace=True)

0 commit comments

Comments
 (0)