Skip to content

Commit be7a7ee

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: f4fadd1 · commit: be7a7ee

File tree

3 files changed

+141
-63
lines changed

3 files changed

+141
-63
lines changed

dev-scripts/benchmark_bottlenecks.py

Lines changed: 91 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -37,19 +37,24 @@ def build_model(n_vars: int, n_cons: int, terms_per_con: int = 5) -> Model:
3737

3838
n_vars_per_dim = int(np.sqrt(n_vars)) + 1
3939
x = m.add_variables(
40-
lower=0, upper=100, name="x",
40+
lower=0,
41+
upper=100,
42+
name="x",
4143
coords=[range(n_vars_per_dim), range(n_vars_per_dim)],
4244
)
4345
y = m.add_variables(
44-
lower=-50, upper=50, name="y",
46+
lower=-50,
47+
upper=50,
48+
name="y",
4549
coords=[range(n_vars_per_dim), range(n_vars_per_dim)],
4650
)
4751

4852
for i in range(n_cons):
4953
var_indices = rng.integers(0, n_vars_per_dim, size=(terms_per_con, 2))
5054
coeffs = rng.uniform(-10, 10, size=terms_per_con)
5155
lhs = sum(
52-
coeffs[j] * (x if j % 2 == 0 else y).isel(
56+
coeffs[j]
57+
* (x if j % 2 == 0 else y).isel(
5358
dim_0=var_indices[j, 0], dim_1=var_indices[j, 1]
5459
)
5560
for j in range(terms_per_con)
@@ -59,7 +64,9 @@ def build_model(n_vars: int, n_cons: int, terms_per_con: int = 5) -> Model:
5964
return m
6065

6166

62-
def time_function(func: Callable[[], Any], repeats: int, warmup: int = 2) -> Iterable[float]:
67+
def time_function(
68+
func: Callable[[], Any], repeats: int, warmup: int = 2
69+
) -> Iterable[float]:
6370
"""Time a function over multiple iterations."""
6471
for _ in range(warmup):
6572
func()
@@ -69,9 +76,7 @@ def time_function(func: Callable[[], Any], repeats: int, warmup: int = 2) -> Ite
6976
yield time.perf_counter() - start
7077

7178

72-
def run_bottleneck_analysis(
73-
model: Model, n_lookups: int, repeats: int
74-
) -> xr.Dataset:
79+
def run_bottleneck_analysis(model: Model, n_lookups: int, repeats: int) -> xr.Dataset:
7580
"""
7681
Run bottleneck analysis on individual methods.
7782
@@ -122,15 +127,19 @@ def bench_optimized_var():
122127

123128
times = np.fromiter(time_function(bench_optimized_var, repeats), dtype=float)
124129
results["get_label_position_vars_optimized"] = xr.DataArray(
125-
times, dims=["repeat"], attrs={"n_operations": n_lookups, "complexity": "O(log n)"}
130+
times,
131+
dims=["repeat"],
132+
attrs={"n_operations": n_lookups, "complexity": "O(log n)"},
126133
)
127134

128135
def bench_optimized_con():
129136
return [constraints.get_label_position(int(l)) for l in con_labels]
130137

131138
times = np.fromiter(time_function(bench_optimized_con, repeats), dtype=float)
132139
results["get_label_position_cons_optimized"] = xr.DataArray(
133-
times, dims=["repeat"], attrs={"n_operations": n_lookups, "complexity": "O(log n)"}
140+
times,
141+
dims=["repeat"],
142+
attrs={"n_operations": n_lookups, "complexity": "O(log n)"},
134143
)
135144

136145
# 3. .sel() calls on constraints
@@ -144,7 +153,12 @@ def bench_sel_calls():
144153

145154
times = np.fromiter(time_function(bench_sel_calls, repeats), dtype=float)
146155
results["sel_calls_constraint"] = xr.DataArray(
147-
times, dims=["repeat"], attrs={"n_operations": n_lookups, "description": "4 sel() calls per constraint"}
156+
times,
157+
dims=["repeat"],
158+
attrs={
159+
"n_operations": n_lookups,
160+
"description": "4 sel() calls per constraint",
161+
},
148162
)
149163

150164
# 4. .sel() calls on variables
@@ -156,7 +170,9 @@ def bench_sel_var():
156170

157171
times = np.fromiter(time_function(bench_sel_var, repeats), dtype=float)
158172
results["sel_calls_variable"] = xr.DataArray(
159-
times, dims=["repeat"], attrs={"n_operations": n_lookups, "description": "2 sel() calls per variable"}
173+
times,
174+
dims=["repeat"],
175+
attrs={"n_operations": n_lookups, "description": "2 sel() calls per variable"},
160176
)
161177

162178
# 5. Nested variable lookups (as in print_single_constraint)
@@ -165,7 +181,12 @@ def bench_nested_var_lookup():
165181

166182
times = np.fromiter(time_function(bench_nested_var_lookup, repeats), dtype=float)
167183
results["nested_var_lookup"] = xr.DataArray(
168-
times, dims=["repeat"], attrs={"n_operations": len(nested_var_labels), "description": "Variable lookups from constraint terms"}
184+
times,
185+
dims=["repeat"],
186+
attrs={
187+
"n_operations": len(nested_var_labels),
188+
"description": "Variable lookups from constraint terms",
189+
},
169190
)
170191

171192
# 6. print_coord formatting
@@ -176,7 +197,12 @@ def bench_print_coord():
176197

177198
times = np.fromiter(time_function(bench_print_coord, repeats), dtype=float)
178199
results["print_coord"] = xr.DataArray(
179-
times, dims=["repeat"], attrs={"n_operations": len(coords_list), "description": "Coordinate formatting"}
200+
times,
201+
dims=["repeat"],
202+
attrs={
203+
"n_operations": len(coords_list),
204+
"description": "Coordinate formatting",
205+
},
180206
)
181207

182208
# 7. String formatting
@@ -186,7 +212,9 @@ def bench_string_format():
186212

187213
times = np.fromiter(time_function(bench_string_format, repeats), dtype=float)
188214
results["string_formatting"] = xr.DataArray(
189-
times, dims=["repeat"], attrs={"n_operations": n_lookups, "description": "f-string formatting"}
215+
times,
216+
dims=["repeat"],
217+
attrs={"n_operations": n_lookups, "description": "f-string formatting"},
190218
)
191219

192220
# Create dataset
@@ -210,8 +238,14 @@ def compute_summary(ds: xr.Dataset) -> xr.Dataset:
210238
times = ds[var_name].values
211239
n_ops = ds[var_name].attrs.get("n_operations", 1)
212240
data[var_name] = xr.DataArray(
213-
[np.median(times), np.mean(times), np.std(times), (np.median(times) / n_ops) * 1e6],
214-
dims=["stat"], coords={"stat": stats}
241+
[
242+
np.median(times),
243+
np.mean(times),
244+
np.std(times),
245+
(np.median(times) / n_ops) * 1e6,
246+
],
247+
dims=["stat"],
248+
coords={"stat": stats},
215249
)
216250

217251
summary = xr.Dataset(data)
@@ -225,8 +259,12 @@ def print_analysis(ds: xr.Dataset, summary: xr.Dataset) -> None:
225259
print("BOTTLENECK ANALYSIS")
226260
print("=" * 80)
227261

228-
print(f"\nModel: {ds.attrs['n_variables']} variables, {ds.attrs['n_constraints']} constraints")
229-
print(f" {ds.attrs['n_variable_arrays']} variable arrays, {ds.attrs['n_constraint_arrays']} constraint arrays")
262+
print(
263+
f"\nModel: {ds.attrs['n_variables']} variables, {ds.attrs['n_constraints']} constraints"
264+
)
265+
print(
266+
f" {ds.attrs['n_variable_arrays']} variable arrays, {ds.attrs['n_constraint_arrays']} constraint arrays"
267+
)
230268
print(f" {ds.attrs['n_lookups']} lookups per benchmark\n")
231269

232270
# get_label_position comparison
@@ -242,7 +280,9 @@ def print_analysis(ds: xr.Dataset, summary: xr.Dataset) -> None:
242280
median_ms = float(summary[key].sel(stat="median_s")) * 1000
243281
per_op = float(summary[key].sel(stat="per_op_us"))
244282
complexity = ds[key].attrs.get("complexity", "")
245-
print(f" {key:<45s} {median_ms:>10.2f}ms {per_op:>10.2f}µs {complexity:<10s}")
283+
print(
284+
f" {key:<45s} {median_ms:>10.2f}ms {per_op:>10.2f}µs {complexity:<10s}"
285+
)
246286

247287
# Speedup calculation
248288
print("\nSpeedup (Optimized vs Original):")
@@ -262,7 +302,13 @@ def print_analysis(ds: xr.Dataset, summary: xr.Dataset) -> None:
262302
print(f" {'Operation':<45s} {'Time':>12s} {'Per Op':>12s}")
263303
print("-" * 80)
264304

265-
other_ops = ["sel_calls_constraint", "sel_calls_variable", "nested_var_lookup", "print_coord", "string_formatting"]
305+
other_ops = [
306+
"sel_calls_constraint",
307+
"sel_calls_variable",
308+
"nested_var_lookup",
309+
"print_coord",
310+
"string_formatting",
311+
]
266312
for key in other_ops:
267313
if key in summary.data_vars:
268314
median_ms = float(summary[key].sel(stat="median_s")) * 1000
@@ -282,22 +328,36 @@ def print_analysis(ds: xr.Dataset, summary: xr.Dataset) -> None:
282328

283329
total = orig_con + sel_con + nested + fmt
284330
print("\nprint_single_constraint breakdown (with original get_label_position):")
285-
print(f" Constraint lookup: {orig_con*1000:>8.2f}ms ({orig_con/total*100:>5.1f}%)")
286-
print(f" .sel() calls: {sel_con*1000:>8.2f}ms ({sel_con/total*100:>5.1f}%)")
287-
print(f" Nested var lookups: {nested*1000:>8.2f}ms ({nested/total*100:>5.1f}%)")
288-
print(f" String formatting: {fmt*1000:>8.2f}ms ({fmt/total*100:>5.1f}%)")
289-
print(f" Total: {total*1000:>8.2f}ms")
331+
print(
332+
f" Constraint lookup: {orig_con * 1000:>8.2f}ms ({orig_con / total * 100:>5.1f}%)"
333+
)
334+
print(
335+
f" .sel() calls: {sel_con * 1000:>8.2f}ms ({sel_con / total * 100:>5.1f}%)"
336+
)
337+
print(
338+
f" Nested var lookups: {nested * 1000:>8.2f}ms ({nested / total * 100:>5.1f}%)"
339+
)
340+
print(f" String formatting: {fmt * 1000:>8.2f}ms ({fmt / total * 100:>5.1f}%)")
341+
print(f" Total: {total * 1000:>8.2f}ms")
290342

291343
# With optimized implementation
292344
opt_con = float(summary["get_label_position_cons_optimized"].sel(stat="median_s"))
293345
total_opt = opt_con + sel_con + nested + fmt
294346
print("\nWith optimized get_label_position:")
295-
print(f" Constraint lookup: {opt_con*1000:>8.2f}ms ({opt_con/total_opt*100:>5.1f}%)")
296-
print(f" .sel() calls: {sel_con*1000:>8.2f}ms ({sel_con/total_opt*100:>5.1f}%)")
297-
print(f" Nested var lookups: {nested*1000:>8.2f}ms ({nested/total_opt*100:>5.1f}%)")
298-
print(f" String formatting: {fmt*1000:>8.2f}ms ({fmt/total_opt*100:>5.1f}%)")
299-
print(f" Total: {total_opt*1000:>8.2f}ms")
300-
print(f"\n Overall speedup: {total/total_opt:.1f}x")
347+
print(
348+
f" Constraint lookup: {opt_con * 1000:>8.2f}ms ({opt_con / total_opt * 100:>5.1f}%)"
349+
)
350+
print(
351+
f" .sel() calls: {sel_con * 1000:>8.2f}ms ({sel_con / total_opt * 100:>5.1f}%)"
352+
)
353+
print(
354+
f" Nested var lookups: {nested * 1000:>8.2f}ms ({nested / total_opt * 100:>5.1f}%)"
355+
)
356+
print(
357+
f" String formatting: {fmt * 1000:>8.2f}ms ({fmt / total_opt * 100:>5.1f}%)"
358+
)
359+
print(f" Total: {total_opt * 1000:>8.2f}ms")
360+
print(f"\n Overall speedup: {total / total_opt:.1f}x")
301361
print()
302362

303363

0 commit comments

Comments
 (0)