Skip to content

Commit dff3144

Browse files
committed
seems to be working.
1 parent 4ccfad0 commit dff3144

File tree

6 files changed

+19
-19
lines changed

6 files changed

+19
-19
lines changed

benchmarks/benchmarking_flux.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
CKPT_ID = "black-forest-labs/FLUX.1-dev"
11+
RESULT_FILENAME = "flux.csv"
1112

1213

1314
def get_input_dict(**device_dtype_kwargs):
@@ -94,4 +95,4 @@ def get_input_dict(**device_dtype_kwargs):
9495
]
9596

9697
runner = BenchmarkMixin()
97-
runner.run_bencmarks_and_collate(scenarios, filename="flux.csv")
98+
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)

benchmarks/benchmarking_ltx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
CKPT_ID = "Lightricks/LTX-Video-0.9.7-dev"
11+
RESULT_FILENAME = "ltx.csv"
1112

1213

1314
def get_input_dict(**device_dtype_kwargs):
@@ -76,4 +77,4 @@ def get_input_dict(**device_dtype_kwargs):
7677
]
7778

7879
runner = BenchmarkMixin()
79-
runner.run_bencmarks_and_collate(scenarios, filename="ltx.csv")
80+
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)

benchmarks/benchmarking_sdxl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0"
11+
RESULT_FILENAME = "sdxl.csv"
1112

1213

1314
def get_input_dict(**device_dtype_kwargs):
@@ -78,4 +79,4 @@ def get_input_dict(**device_dtype_kwargs):
7879
]
7980

8081
runner = BenchmarkMixin()
81-
runner.run_bencmarks_and_collate(scenarios, filename="sdxl.csv")
82+
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)

benchmarks/benchmarking_wan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
CKPT_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
11+
RESULT_FILENAME = "wan.csv"
1112

1213

1314
def get_input_dict(**device_dtype_kwargs):
@@ -70,4 +71,4 @@ def get_input_dict(**device_dtype_kwargs):
7071
]
7172

7273
runner = BenchmarkMixin()
73-
runner.run_bencmarks_and_collate(scenarios, filename="wan.csv")
74+
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)

benchmarks/push_results.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,7 @@ def push_to_hf_dataset():
4545

4646
for column in numeric_columns:
4747
# get previous values as floats, aligned to current index
48-
prev_vals = (
49-
previous_results[column]
50-
.map(filter_float)
51-
.reindex(current_results.index)
52-
)
48+
prev_vals = previous_results[column].map(filter_float).reindex(current_results.index)
5349

5450
# get current values as floats
5551
curr_vals = current_results[column].astype(float)
@@ -58,21 +54,16 @@ def push_to_hf_dataset():
5854
curr_str = curr_vals.map(str)
5955

6056
# build an appendage only when prev exists and differs
61-
append_str = prev_vals.where(
62-
prev_vals.notnull() & (prev_vals != curr_vals),
63-
other=pd.NA
64-
).map(lambda x: f" ({x})" if pd.notnull(x) else "")
57+
append_str = prev_vals.where(prev_vals.notnull() & (prev_vals != curr_vals), other=pd.NA).map(
58+
lambda x: f" ({x})" if pd.notnull(x) else ""
59+
)
6560

6661
# combine
6762
current_results[column] = curr_str + append_str
6863

6964
current_results.to_csv(FINAL_CSV_FILENAME, index=False)
7065

71-
commit_message = (
72-
f"upload from sha: {GITHUB_SHA}"
73-
if GITHUB_SHA is not None else
74-
"upload benchmark results"
75-
)
66+
commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
7667
upload_file(
7768
repo_id=REPO_ID,
7869
path_in_repo=FINAL_CSV_FILENAME,

benchmarks/run_all.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44

55
import pandas as pd
66

7+
from diffusers.utils import logging
8+
79

810
PATTERN = "benchmarking_*.py"
911
FINAL_CSV_FILENAME = "collated_results.csv"
1012
GITHUB_SHA = os.getenv("GITHUB_SHA", None)
1113

1214

15+
logger = logging.get_logger(__name__)
16+
17+
1318
class SubprocessCallException(Exception):
1419
pass
1520

@@ -37,7 +42,7 @@ def run_scripts():
3742

3843
for file in python_files:
3944
if file != "benchmarking_utils.py":
40-
print(f"****** Running file: {file} ******")
45+
logger.info(f"****** Running file: {file} ******")
4146
command = f"python {file}"
4247
run_command(command.split())
4348

0 commit comments

Comments
 (0)