Skip to content

Commit e911777

Browse files
committed
feat(zero_shot): add win rate chart generator
- Add WinRateChartGenerator class for visualizing model rankings - Support customizable chart styles, colors, and annotations - Add matplotlib dependency to pyproject.toml - Update schema with ChartConfig dataclass - Integrate chart generation into zero_shot_pipeline
1 parent ccf68cf commit e911777

File tree

6 files changed

+376
-2
lines changed

6 files changed

+376
-2
lines changed
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
# -*- coding: utf-8 -*-
2+
"""Chart generator for zero-shot evaluation results.
3+
4+
This module provides visualization capabilities for evaluation results,
5+
generating beautiful bar charts to display model win rates.
6+
"""
7+
8+
from pathlib import Path
9+
from typing import TYPE_CHECKING, List, Optional, Tuple
10+
11+
from loguru import logger
12+
13+
if TYPE_CHECKING:
14+
from cookbooks.zero_shot_evaluation.schema import ChartConfig
15+
16+
17+
class WinRateChartGenerator:
18+
"""Generator for win rate comparison charts.
19+
20+
Creates visually appealing bar charts showing model rankings
21+
based on pairwise evaluation results.
22+
23+
Attributes:
24+
config: Chart configuration options
25+
26+
Example:
27+
>>> generator = WinRateChartGenerator(config)
28+
>>> path = generator.generate(
29+
... rankings=[("GPT-4", 0.73), ("Claude", 0.65)],
30+
... output_dir="./results",
31+
... task_description="Translation evaluation",
32+
... )
33+
"""
34+
35+
# Color palette - inspired by modern data visualization
36+
ACCENT_COLOR = "#FF6B35" # Vibrant orange for best model
37+
ACCENT_HATCH = "///" # Diagonal stripes pattern
38+
BAR_COLORS = [
39+
"#4A4A4A", # Dark gray
40+
"#6B6B6B", # Medium gray
41+
"#8C8C8C", # Light gray
42+
"#ADADAD", # Lighter gray
43+
"#CECECE", # Very light gray
44+
]
45+
46+
def __init__(self, config: Optional["ChartConfig"] = None):
47+
"""Initialize chart generator.
48+
49+
Args:
50+
config: Chart configuration. Uses defaults if not provided.
51+
"""
52+
self.config = config
53+
54+
def _configure_cjk_font(self, plt, font_manager) -> Optional[str]:
55+
"""Configure matplotlib to support CJK (Chinese/Japanese/Korean) characters.
56+
57+
Attempts to find and use a system font that supports CJK characters.
58+
Falls back gracefully if no suitable font is found.
59+
60+
Returns:
61+
Font name if found, None otherwise
62+
"""
63+
# Common CJK fonts on different platforms (simplified Chinese priority)
64+
cjk_fonts = [
65+
# macOS - Simplified Chinese (verified available)
66+
"Hiragino Sans GB",
67+
"Songti SC",
68+
"Kaiti SC",
69+
"Heiti SC",
70+
"Lantinghei SC",
71+
"PingFang SC",
72+
"STFangsong",
73+
# Windows
74+
"Microsoft YaHei",
75+
"SimHei",
76+
"SimSun",
77+
# Linux
78+
"Noto Sans CJK SC",
79+
"WenQuanYi Micro Hei",
80+
"Droid Sans Fallback",
81+
# Generic
82+
"Arial Unicode MS",
83+
]
84+
85+
# Get available fonts
86+
available_fonts = {f.name for f in font_manager.fontManager.ttflist}
87+
88+
# Find the first available CJK font
89+
for font_name in cjk_fonts:
90+
if font_name in available_fonts:
91+
plt.rcParams["font.sans-serif"] = [font_name] + plt.rcParams.get("font.sans-serif", [])
92+
plt.rcParams["axes.unicode_minus"] = False # Fix minus sign display
93+
logger.debug(f"Using CJK font: {font_name}")
94+
return font_name
95+
96+
# No CJK font found, log warning
97+
logger.warning(
98+
"No CJK font found. Chinese characters may not display correctly. "
99+
"Consider installing a CJK font like 'Noto Sans CJK SC'."
100+
)
101+
return None
102+
103+
def generate(
104+
self,
105+
rankings: List[Tuple[str, float]],
106+
output_dir: str,
107+
task_description: Optional[str] = None,
108+
total_queries: int = 0,
109+
total_comparisons: int = 0,
110+
) -> Optional[Path]:
111+
"""Generate win rate bar chart.
112+
113+
Args:
114+
rankings: List of (model_name, win_rate) tuples, sorted by win rate
115+
output_dir: Directory to save the chart
116+
task_description: Task description for subtitle
117+
total_queries: Number of queries evaluated
118+
total_comparisons: Number of pairwise comparisons
119+
120+
Returns:
121+
Path to saved chart file, or None if generation failed
122+
"""
123+
if not rankings:
124+
logger.warning("No rankings data to visualize")
125+
return None
126+
127+
try:
128+
import matplotlib.patches as mpatches
129+
import matplotlib.pyplot as plt
130+
import numpy as np
131+
from matplotlib import font_manager
132+
except ImportError:
133+
logger.warning("matplotlib not installed. Install with: pip install matplotlib")
134+
return None
135+
136+
# Extract config values
137+
figsize = self.config.figsize if self.config else (12, 7)
138+
dpi = self.config.dpi if self.config else 150
139+
fmt = self.config.format if self.config else "png"
140+
show_values = self.config.show_values if self.config else True
141+
highlight_best = self.config.highlight_best if self.config else True
142+
custom_title = self.config.title if self.config else None
143+
144+
# Prepare data (already sorted high to low)
145+
model_names = [r[0] for r in rankings]
146+
win_rates = [r[1] * 100 for r in rankings] # Convert to percentage
147+
n_models = len(model_names)
148+
149+
# Setup figure with modern styling (MUST be before font config)
150+
plt.style.use("seaborn-v0_8-whitegrid")
151+
152+
# Configure font for CJK (Chinese/Japanese/Korean) support
153+
# This MUST be after plt.style.use() as style resets font settings
154+
self._configure_cjk_font(plt, font_manager)
155+
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
156+
157+
# Create bar positions
158+
x_pos = np.arange(n_models)
159+
bar_width = 0.6
160+
161+
# Determine colors for each bar
162+
colors = []
163+
edge_colors = []
164+
hatches = []
165+
166+
for i in range(n_models):
167+
if i == 0 and highlight_best:
168+
# Best model gets accent color with hatch pattern
169+
colors.append(self.ACCENT_COLOR)
170+
edge_colors.append(self.ACCENT_COLOR)
171+
hatches.append(self.ACCENT_HATCH)
172+
else:
173+
# Other models get grayscale
174+
color_idx = min(i - 1, len(self.BAR_COLORS) - 1) if highlight_best else min(i, len(self.BAR_COLORS) - 1)
175+
colors.append(self.BAR_COLORS[color_idx])
176+
edge_colors.append(self.BAR_COLORS[color_idx])
177+
hatches.append("")
178+
179+
# Draw bars
180+
bars = ax.bar(
181+
x_pos,
182+
win_rates,
183+
width=bar_width,
184+
color=colors,
185+
edgecolor=edge_colors,
186+
linewidth=1.5,
187+
zorder=3,
188+
)
189+
190+
# Add hatch pattern to best model
191+
if highlight_best and n_models > 0:
192+
bars[0].set_hatch(self.ACCENT_HATCH)
193+
bars[0].set_edgecolor("white")
194+
195+
# Add value labels on top of bars
196+
if show_values:
197+
for i, (bar, rate) in enumerate(zip(bars, win_rates)):
198+
height = bar.get_height()
199+
ax.annotate(
200+
f"{rate:.1f}",
201+
xy=(bar.get_x() + bar.get_width() / 2, height),
202+
xytext=(0, 5),
203+
textcoords="offset points",
204+
ha="center",
205+
va="bottom",
206+
fontsize=12,
207+
fontweight="bold",
208+
color="#333333",
209+
)
210+
211+
# Customize axes
212+
ax.set_xticks(x_pos)
213+
ax.set_xticklabels(model_names, fontsize=11, fontweight="medium")
214+
ax.set_ylabel("Win Rate (%)", fontsize=12, fontweight="medium", labelpad=10)
215+
ax.set_ylim(0, min(100, max(win_rates) * 1.15)) # Add headroom for labels
216+
217+
# Remove top and right spines
218+
ax.spines["top"].set_visible(False)
219+
ax.spines["right"].set_visible(False)
220+
ax.spines["left"].set_color("#CCCCCC")
221+
ax.spines["bottom"].set_color("#CCCCCC")
222+
223+
# Customize grid
224+
ax.yaxis.grid(True, linestyle="--", alpha=0.5, color="#DDDDDD", zorder=0)
225+
ax.xaxis.grid(False)
226+
227+
# Title
228+
title = custom_title or "Model Win Rate Comparison"
229+
ax.set_title(title, fontsize=16, fontweight="bold", pad=20, color="#333333")
230+
231+
# Subtitle with evaluation info
232+
subtitle_parts = []
233+
if task_description:
234+
# Truncate long descriptions
235+
desc = task_description[:60] + "..." if len(task_description) > 60 else task_description
236+
subtitle_parts.append(f"Task: {desc}")
237+
if total_queries > 0:
238+
subtitle_parts.append(f"Queries: {total_queries}")
239+
if total_comparisons > 0:
240+
subtitle_parts.append(f"Comparisons: {total_comparisons}")
241+
242+
if subtitle_parts:
243+
subtitle = " | ".join(subtitle_parts)
244+
ax.text(
245+
0.5,
246+
1.02,
247+
subtitle,
248+
transform=ax.transAxes,
249+
ha="center",
250+
va="bottom",
251+
fontsize=10,
252+
color="#666666",
253+
style="italic",
254+
)
255+
256+
# Create legend
257+
legend_elements = []
258+
if highlight_best and n_models > 0:
259+
best_patch = mpatches.Patch(
260+
facecolor=self.ACCENT_COLOR,
261+
edgecolor="white",
262+
hatch=self.ACCENT_HATCH,
263+
label=f"Best: {model_names[0]}",
264+
)
265+
legend_elements.append(best_patch)
266+
267+
if legend_elements:
268+
ax.legend(
269+
handles=legend_elements,
270+
loc="upper right",
271+
frameon=True,
272+
framealpha=0.9,
273+
fontsize=10,
274+
)
275+
276+
# Tight layout
277+
plt.tight_layout()
278+
279+
# Save chart
280+
output_path = Path(output_dir)
281+
output_path.mkdir(parents=True, exist_ok=True)
282+
chart_file = output_path / f"win_rate_chart.{fmt}"
283+
284+
plt.savefig(
285+
chart_file,
286+
format=fmt,
287+
dpi=dpi,
288+
bbox_inches="tight",
289+
facecolor="white",
290+
edgecolor="none",
291+
)
292+
plt.close(fig)
293+
294+
logger.info(f"Win rate chart saved to {chart_file}")
295+
return chart_file

cookbooks/zero_shot_evaluation/schema.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,25 @@ class OutputConfig(BaseModel):
9292
output_dir: str = Field(default="./evaluation_results", description="Output directory")
9393

9494

95+
class ChartConfig(BaseModel):
96+
"""Chart generation configuration."""
97+
98+
enabled: bool = Field(default=True, description="Whether to generate win rate chart")
99+
title: Optional[str] = Field(default=None, description="Chart title (auto-generated if not set)")
100+
figsize: tuple = Field(default=(12, 7), description="Figure size (width, height) in inches")
101+
dpi: int = Field(default=150, ge=72, le=300, description="Image resolution")
102+
format: Literal["png", "svg", "pdf"] = Field(default="png", description="Output format")
103+
show_values: bool = Field(default=True, description="Show values on top of bars")
104+
highlight_best: bool = Field(default=True, description="Highlight the best model with accent color")
105+
106+
95107
class ReportConfig(BaseModel):
96108
"""Report generation configuration."""
97109

98110
enabled: bool = Field(default=False, description="Whether to generate report")
99111
language: Literal["zh", "en"] = Field(default="zh", description="Report language: zh | en")
100112
include_examples: int = Field(default=3, ge=1, le=10, description="Examples per section")
113+
chart: ChartConfig = Field(default_factory=ChartConfig, description="Chart configuration")
101114

102115

103116
class ZeroShotConfig(BaseModel):

cookbooks/zero_shot_evaluation/zero_shot_pipeline.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from loguru import logger
2424
from pydantic import BaseModel, Field
2525

26+
from cookbooks.zero_shot_evaluation.chart_generator import WinRateChartGenerator
2627
from cookbooks.zero_shot_evaluation.query_generator import QueryGenerator
2728
from cookbooks.zero_shot_evaluation.response_collector import ResponseCollector
2829
from cookbooks.zero_shot_evaluation.schema import (
@@ -702,6 +703,10 @@ async def evaluate(
702703
if self.config.report.enabled:
703704
await self._generate_and_save_report(result)
704705

706+
# Step 7: Generate win rate chart if enabled
707+
if self.config.report.chart.enabled:
708+
self._generate_win_rate_chart(result)
709+
705710
return result
706711

707712
async def _generate_and_save_report(self, result: EvaluationResult) -> None:
@@ -728,6 +733,24 @@ async def _generate_and_save_report(self, result: EvaluationResult) -> None:
728733
f.write(report)
729734
logger.info(f"Report saved to {report_path}")
730735

736+
def _generate_win_rate_chart(self, result: EvaluationResult) -> None:
737+
"""Generate and save win rate comparison chart."""
738+
logger.info("Step 7: Generating win rate chart...")
739+
740+
chart_config = self.config.report.chart
741+
generator = WinRateChartGenerator(config=chart_config)
742+
743+
chart_path = generator.generate(
744+
rankings=result.rankings,
745+
output_dir=self.config.output.output_dir,
746+
task_description=self.config.task.description,
747+
total_queries=result.total_queries,
748+
total_comparisons=result.total_comparisons,
749+
)
750+
751+
if chart_path:
752+
logger.info(f"Win rate chart saved to {chart_path}")
753+
731754
def _display_results(self, result: EvaluationResult) -> None:
732755
"""Display evaluation results with formatted output."""
733756
endpoint_names = list(self.config.target_endpoints.keys())

0 commit comments

Comments
 (0)