Skip to content

Commit 7b4f37e

Browse files
Merge pull request #33 from TRI-ML/comparison_tools
Multi-policy comparison convenience tools
2 parents 6db47ef + 33443cf commit 7b4f37e

File tree

6 files changed

+762
-0
lines changed

6 files changed

+762
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
*.ipynb_checkpoints*
33
*.pytest_cache*
44
*.egg-info*
5+
dist/*

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ decision = result.decision
126126
print(decision) # FailToDecide: difference was not statistically separable; user can collect 100 - 4 = 96 more rollouts for each policy to re-run the test.
127127
```
128128

129+
### More Working Examples
130+
- `quick_start.ipynb` presents a single-task policy comparison example using actual hardware evaluation results.
131+
- `multi_policy_comparison_example.ipynb` describes how to compare multiple policies at once, on both a single task and multiple tasks.
132+
129133
## Key Notes for Understanding the Core Ideas of STEP Code
130134

131135
We include key notes for understanding the core ideas of the STEP code. Quick-start resources are included in both shell script and notebook form.

multi_policy_comparison_example.ipynb

Lines changed: 362 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""A collection of convenience tools for analysis and presentation of test results."""
Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
from typing import Dict, List, Optional, Tuple, Union
2+
3+
from matplotlib.cm import get_cmap
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
from scipy import stats
7+
8+
from sequentialized_barnard_tests import Decision, Hypothesis
9+
from sequentialized_barnard_tests.step import MirroredStepTest
10+
11+
12+
def compact_letter_display(
    significant_pair_list: List[Tuple[str, str]],
    sorted_model_list: List[str],
) -> List[str]:
    """Build a Compact Letter Display (CLD) from pairwise significance results.

    Follows the insert-and-absorb procedure described in "An Algorithm for a
    Letter-Based Representation of All-Pairwise Comparisons" by Piepho (2004).
    Two models end up sharing at least one letter exactly when they were not
    listed as a significant pair.

    Args:
        significant_pair_list: Tuples of model names that were deemed
            significantly different by each A/B test.
        sorted_model_list: Model names sorted by performance in descending
            order.

    Returns:
        One string of letters per model, in the order of `sorted_model_list`.
    """
    num_models = len(sorted_model_list)

    # Work with positions in the sorted list rather than with names.
    position_of = {name: pos for pos, name in enumerate(sorted_model_list)}
    index_pairs = [
        (position_of[left], position_of[right])
        for left, right in significant_pair_list
    ]

    def absorb(columns):
        """Drop, one at a time, the first column contained in another column."""
        while True:
            redundant = next(
                (
                    a
                    for a, col_a in enumerate(columns)
                    for b, col_b in enumerate(columns)
                    if a != b and col_a <= col_b
                ),
                None,
            )
            if redundant is None:
                return columns
            columns.pop(redundant)

    # Each column is the set of model positions carrying that column's letter.
    # Initially a single column covers every model.
    columns = [set(range(num_models))]

    for first, second in index_pairs:
        while True:
            # Locate the first column that still joins the two models.
            clash = next(
                (
                    k
                    for k, members in enumerate(columns)
                    if first in members and second in members
                ),
                None,
            )
            if clash is None:
                break
            # Split it: the duplicate loses `first`, the original loses `second`.
            duplicate = columns[clash] - {first}
            columns[clash].discard(second)
            columns.append(duplicate)
            absorb(columns)

    # Order columns by the best-ranked model they contain (stable for ties),
    # so that letters are handed out from the top of the ranking downwards.
    columns.sort(key=lambda members: min(members, default=num_models))

    return [
        "".join(
            chr(ord("a") + k)
            for k, members in enumerate(columns)
            if pos in members
        )
        for pos in range(num_models)
    ]
98+
99+
100+
def compare_success_and_get_cld(
    model_name_list: List[str],  # [model_0, ...]
    success_array_list: List[np.ndarray],  # [success_array_for_model_0, ...]
    global_confidence_level: float,
    max_sample_size_per_model: int,
    shuffle: bool,
    rng: Optional[np.random.Generator] = None,
    verbose: bool = True,
) -> Dict[str, str]:
    """Compares multiple success arrays and returns their Compact Letter Display (CLD)
    representation based on pairwise tests with STEP.

    The global significance level is split evenly across all pairwise
    comparisons, and every pair of models is tested on the common prefix of
    their two success arrays.

    Args:
        model_name_list: A list of unique model names.
        success_array_list: A list of binary arrays indicating success/failure
            for each model. The arrays are never modified; shuffling is done on
            copies.
        global_confidence_level: The desired global confidence level for the
            multiple comparisons.
        max_sample_size_per_model: The maximum sample size to use for comparison
            (per model). You must set this number based on your experimental budget
            before initiating your statistical analysis.
        shuffle: Whether to shuffle the True/False ordering of each success array
            before comparison. Set it to False if each True/False outcome is
            independent within each array. Set to True if, for example, each array is a
            concatenation of results from multiple tasks and you want to measure the
            aggregate multi-task performance.
        rng: Optional random number generator instance for shuffling. Only used if
            shuffle is True.
        verbose: Whether to print detailed output. Defaults to True.

    Returns:
        A dictionary mapping model names to their CLD letters, inserted in
        descending order of empirical success rate.

    Raises:
        ValueError: If shuffle is True but rng is None, or if model_name_list
            contains duplicate names.
    """
    if shuffle and rng is None:
        raise ValueError("rng must be provided when shuffle is True.")
    # Results are keyed by model name, so duplicates would silently overwrite
    # each other and corrupt the letter assignment.
    if len(set(model_name_list)) != len(model_name_list):
        raise ValueError("model_name_list must not contain duplicate names.")
    num_models = len(model_name_list)

    # Set up the sequential statistical test.
    global_alpha = 1 - global_confidence_level
    num_comparisons = num_models * (num_models - 1) // 2
    # With fewer than two models there is nothing to compare; keep the full
    # alpha instead of dividing by zero.
    individual_alpha = global_alpha / max(num_comparisons, 1)
    individual_confidence_level = 1 - individual_alpha
    if verbose:
        print("Statistical Test Specs:")
        print(" Method: STEP")
        print(f" Global Confidence: {round(global_confidence_level, 5)}")
        print(f" ({round(individual_confidence_level, 5)} per comparison)")
        print(f" Maximum Sample Size per Model: {max_sample_size_per_model}\n")
    test = MirroredStepTest(
        alternative=Hypothesis.P0LessThanP1,
        alpha=individual_alpha,
        n_max=max_sample_size_per_model,
    )
    test.reset()

    # Prepare success array per model.
    success_array_dict = dict()  # model_name -> success_array
    for idx, model in enumerate(model_name_list):
        success_array = success_array_list[idx]
        if shuffle:
            # Generator.shuffle works in place; shuffle a copy so that the
            # caller's data keeps its original order.
            success_array = np.array(success_array, copy=True)
            rng.shuffle(success_array)
        success_array_dict[model] = success_array

    # Run pairwise comparisons.
    comparisons_dict = dict()  # (model_name_a, model_name_b) -> Decision
    for idx_a in range(num_models):
        for idx_b in range(idx_a + 1, num_models):
            model_a = model_name_list[idx_a]
            model_b = model_name_list[idx_b]
            array_a = success_array_dict[model_a]
            array_b = success_array_dict[model_b]
            # Only the common prefix of the two arrays is compared.
            len_common = min(len(array_a), len(array_b))
            test_result = test.run_on_sequence(
                array_a[:len_common], array_b[:len_common]
            )
            comparisons_dict[(model_a, model_b)] = test_result.decision

    # Compact Letter Display algorithm to summarize results.
    # Any decision other than FailToDecide counts as a significant pair.
    input_list_to_cld = [
        pair
        for pair, decision in comparisons_dict.items()
        if decision != Decision.FailToDecide
    ]
    models_sorted_by_success_rates = [
        model
        for model, _ in sorted(
            success_array_dict.items(),
            key=lambda kv_pair: (np.mean(kv_pair[1]) if len(kv_pair[1]) else 0.0),
            reverse=True,
        )
    ]
    letters_list = compact_letter_display(
        input_list_to_cld, models_sorted_by_success_rates
    )
    if verbose:
        print("Statistical Test Results (Compact Letter Display):")
    str_padding = max(
        (len(model) for model in models_sorted_by_success_rates), default=0
    )
    return_dict = dict()
    for letters, model in zip(letters_list, models_sorted_by_success_rates):
        return_dict[model] = letters
        if verbose:
            num_successes = np.sum(success_array_dict[model])
            num_trials = len(success_array_dict[model])
            if num_trials == 0:
                empirical_success_rate = 0.0
            else:
                empirical_success_rate = np.mean(success_array_dict[model])
            print(
                f" CLD for {model:<{str_padding}}: {letters}\n"
                f" Success Rate {num_successes} / {num_trials} = "
                f"{round(empirical_success_rate, 3)}",
            )

    # Ranks are determined if each policy has a unique single letter.
    all_order_determined = all(len(letters) == 1 for letters in letters_list) and len(
        set(letters_list)
    ) == len(model_name_list)
    if verbose:
        if all_order_determined:
            print(
                (
                    "All models separated with global confidence of "
                    f"{round(global_confidence_level, 5)}."
                )
            )
        else:
            print(
                (
                    "Not all models were separated with global confidence of "
                    f"{round(global_confidence_level, 5)}. Models that share "
                    "a same letter are not separated from each other with "
                    "statistical significance. For more information on how to "
                    "interpret the letters, see: "
                    "https://en.wikipedia.org/wiki/Compact_letter_display.\n"
                )
            )
    return return_dict
236+
237+
238+
def draw_samples_from_beta_posterior(
    success_array: np.ndarray,
    rng: np.random.Generator,
    num_samples: int = 10000,
    alpha_prior: float = 1,
    beta_prior: float = 1,
) -> np.ndarray:
    """Sample success-rate values from the Beta posterior of a Bernoulli process.

    The posterior is Beta(alpha_prior + #successes, beta_prior + #failures),
    where the counts are taken from `success_array`. The default prior
    parameters of (1, 1) correspond to a uniform prior.

    Args:
        success_array: A binary array with True/False indicating success/failure.
        rng: A numpy random Generator instance.
        num_samples: Optional number of samples to draw. Defaults to 10000.
        alpha_prior: Optional alpha parameter of the beta prior. Defaults to 1.
        beta_prior: Optional beta parameter of the beta prior. Defaults to 1.

    Returns:
        Samples drawn from the beta posterior distribution.
    """
    success_count = np.sum(success_array)
    failure_count = len(success_array) - success_count
    return stats.beta.rvs(
        alpha_prior + success_count,
        beta_prior + failure_count,
        size=num_samples,
        random_state=rng,
    )
265+
266+
267+
def plot_model_comparison(
    model_name_list: List[str],
    success_arrays: List[np.ndarray],
    cld_letters: Union[List[str], Dict[str, str]],
    rng: np.random.Generator,
    output_path: Optional[str] = None,
    title: Optional[str] = None,
    add_legend: bool = False,
    unit_width: int = 6,
    height: int = 4,
    dpi: int = 100,
) -> Union[None, plt.Figure]:
    """Makes a violin plot of success rate estimates with corresponding CLD letters
    for policy comparison.

    Each violin shows samples from the Beta posterior of one model's success
    rate; the CLD label is drawn next to the posterior mean.

    Args:
        model_name_list: A list of model names.
        success_arrays: A list of arrays indicating success/failure for each model,
            in the same order as model_name_list.
        cld_letters: Either a list of CLD letters in the same order as
            model_name_list, or a dictionary mapping model names to CLD letters
            (as returned by compare_success_and_get_cld), which is re-ordered
            to match model_name_list.
        rng: A numpy random Generator instance for posterior sampling.
        output_path: Optional file path to save the plot. If None, the plot will not
            be saved but returned as a matplotlib Figure object. Defaults to None.
        title: Optional title for the plot. Defaults to None.
        add_legend: Whether to show legend on the plot. Defaults to False.
        unit_width: Minimum figure width. The figure is
            max(unit_width, number of models) wide. Defaults to 6.
        height: Figure height. Defaults to 4.
        dpi: Resolution used when creating the figure. Note that a file written
            to output_path is always saved at 300 dpi. Defaults to 100.

    Returns:
        If output_path is None, returns a matplotlib Figure object containing
        the plot. Otherwise, saves the plot to the specified path and returns None.
    """
    num_models = len(model_name_list)

    # A name -> letters mapping is keyed by model, so align it with the
    # plotting order instead of relying on the mapping's own ordering.
    if isinstance(cld_letters, dict):
        cld_letters = [cld_letters[model] for model in model_name_list]

    posterior_samples = [
        draw_samples_from_beta_posterior(success_array, rng)
        for success_array in success_arrays
    ]
    means = [np.mean(samples) for samples in posterior_samples]

    fig, ax = plt.subplots(figsize=(max(unit_width, num_models), height), dpi=dpi)

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9. The palette has 10 colors, hence the modulo.
    cmap = plt.get_cmap("tab10")
    colors = [cmap(i % 10) for i in range(num_models)]

    parts = ax.violinplot(
        posterior_samples,
        positions=np.arange(num_models),
        showmeans=True,
        showmedians=False,
        showextrema=False,
        widths=0.8,
    )
    for pc, color in zip(parts["bodies"], colors):
        pc.set_facecolor(color)
        pc.set_alpha(0.6)
    parts["cmeans"].set_color("black")
    parts["cmeans"].set_linewidth(0.8)

    # Add CLD labels, offset slightly up and to the right of each mean marker.
    for x, y, label in zip(range(num_models), means, cld_letters):
        ax.text(
            x + 0.15,
            y + 0.03,
            label,
            fontsize=12,
            fontweight="bold",
            color="black",
            verticalalignment="center",
            zorder=4,
        )

    ax.set_xticks(np.arange(num_models))
    ax.set_xticklabels(model_name_list, rotation=0, ha="center")
    ax.set_ylim(0.0, 1.0)
    ax.set_ylabel("Success Rate")
    if title is not None:
        ax.set_title(title)
    if add_legend:
        ax.legend(parts["bodies"], model_name_list, loc="best")
    fig.tight_layout()

    if output_path is None:
        return fig

    # Save and close this specific figure rather than whichever figure
    # pyplot currently considers active.
    fig.savefig(output_path, dpi=300)
    plt.close(fig)
    print(f"Saved a PNG plot to {output_path}")
    return None

0 commit comments

Comments
 (0)