Commit 592f76c

Final refinements to bumplot visualization
Implemented user-requested refinements:

1. Larger fonts:
   - Axis labels increased to 18pt (was 14pt)
   - Tick labels increased to 14pt (was 12pt)
2. Colorbar positioning:
   - Moved to x=0.90 (was 0.86) to leave a proper gap from the figure
   - Labels above/below remain properly aligned
3. Y-axis range:
   - Increased padding to prevent token cutoff
   - Now uses (0.3 to max_rank + 0.8) instead of (0.7 to max_rank + 0.3)
4. Token display:
   - Increased max length to 15 chars (was 10)
   - Adaptive font sizing: 10pt for long tokens, 12pt standard
   - Smart truncation at 13 chars + '..'
5. Curve transparency:
   - Reduced alpha to 0.2 (was 0.5), making overlapping paths more visible (see the sketch below)
   - Added rasterized=False to ensure vector rendering

All tokens are now fully visible, curves show proper transparency, and the overall layout has better spacing and readability.
1 parent bb98074 commit 592f76c
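
The effect of the curve-transparency change (item 5 above) is easiest to see in isolation. Below is a minimal, self-contained matplotlib sketch (not code from this repository) that overlays twenty nearly identical curves at the old alpha of 0.5 and the new alpha of 0.2; the data, colors, and figure size are made up purely for illustration.

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 200)
fig, (ax_old, ax_new) = plt.subplots(1, 2, figsize=(10, 3), sharey=True)
for ax, alpha, title in [(ax_old, 0.5, 'alpha=0.5 (before)'),
                         (ax_new, 0.2, 'alpha=0.2 (after)')]:
    for offset in np.linspace(-0.2, 0.2, 20):
        # Overlapping paths: with low alpha, darker regions show where
        # many curves coincide instead of merging into a solid band.
        ax.plot(x, np.sin(x) + offset, color='tab:green', alpha=alpha,
                linewidth=2.5, solid_capstyle='round', rasterized=False)
    ax.set_title(title)
fig.savefig('alpha_comparison.png', dpi=150)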

File tree

4 files changed: +70 / -21 lines changed

code/quantum_conversations/custom_bumplot.py

Lines changed: 13 additions & 10 deletions
@@ -21,7 +21,7 @@ def create_custom_bumplot(
     metadata: Dict,
     ax: plt.Axes,
     colormap: str = 'RdYlGn',
-    alpha: float = 0.5,  # Reduced alpha as requested
+    alpha: float = 0.2,  # Lower alpha for better transparency visibility
     linewidth: float = 2.5,
     curve_force: float = 0.3
 ) -> None:
@@ -99,10 +99,10 @@ def create_custom_bumplot(
             x_segment = np.concatenate([x_segment, x_overlap])
             y_segment = np.concatenate([y_segment, y_overlap])
 
-            # Plot this segment
+            # Plot this segment with explicit alpha
             ax.plot(x_segment, y_segment, color=color, alpha=alpha,
                     linewidth=linewidth, solid_capstyle='round',
-                    solid_joinstyle='round', zorder=100)  # Below labels
+                    solid_joinstyle='round', zorder=100, rasterized=False)  # Below labels
 
         except Exception as e:
             # Fallback to linear segments
@@ -332,13 +332,13 @@ def add_token_labels(
         # Better special character handling
         token_text = token_text.replace('\n', '↵').replace('\t', '→')
         token_text = token_text.replace('\r', '↲')
-        # Smart truncation
-        if len(token_text) > 10:
+        # Increased max length, smart truncation
+        if len(token_text) > 15:
             # Try to break at word boundary
-            if ' ' in token_text[:10]:
-                token_text = token_text[:token_text[:10].rfind(' ')] + '..'
+            if ' ' in token_text[:15]:
+                token_text = token_text[:token_text[:15].rfind(' ')] + '..'
             else:
-                token_text = token_text[:8] + '..'
+                token_text = token_text[:13] + '..'
         # Handle empty tokens
         if not token_text or token_text.isspace():
             token_text = '[space]'  # Text representation for space
@@ -408,8 +408,11 @@ def add_token_labels(
             # Skip this label if no position found
             continue
 
-        # Uniform font size for all tokens
-        fontsize = 12  # Increased and uniform
+        # Adaptive font size based on token length
+        if len(token_text) > 10:
+            fontsize = 10  # Smaller for long tokens
+        else:
+            fontsize = 12  # Standard size
         weight = 'normal'  # Consistent weight
 
         # Add token label with opaque background
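
For readability, here is the token-label logic from the hunks above pulled out into one standalone helper. The function name format_token_label is hypothetical (the repository keeps this logic inline in add_token_labels); the rules themselves match the diff: visible replacements for control characters, truncation past 15 characters at a word boundary where possible (otherwise 13 characters plus '..'), a '[space]' placeholder for empty tokens, and a 10pt font for labels longer than 10 characters.

def format_token_label(token_text: str) -> tuple:
    """Hypothetical helper mirroring the label rules in the diff above.

    Returns (display_text, fontsize); not a function in the module itself.
    """
    # Make whitespace control characters visible
    token_text = token_text.replace('\n', '↵').replace('\t', '→').replace('\r', '↲')
    # Truncate long tokens, preferring a word boundary within the first 15 chars
    if len(token_text) > 15:
        if ' ' in token_text[:15]:
            token_text = token_text[:token_text[:15].rfind(' ')] + '..'
        else:
            token_text = token_text[:13] + '..'
    # Empty or whitespace-only tokens get a readable placeholder
    if not token_text or token_text.isspace():
        token_text = '[space]'
    # Adaptive font size: long labels drop from 12pt to 10pt
    fontsize = 10 if len(token_text) > 10 else 12
    return token_text, fontsize

# Example: a 20-character token is cut to 13 chars + '..' and rendered at 10pt
print(format_token_label('supercalifragilistic'))  # ('supercalifrag..', 10)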

code/quantum_conversations/visualizer.py

Lines changed: 11 additions & 11 deletions
@@ -395,24 +395,24 @@ def visualize_bumplot(
             valid_ranks = [r for r in ranks_used if 1 <= r <= max_vocab_display]
             if valid_ranks:
                 max_rank_used = int(max(valid_ranks))
-                # Set y limits with rank 1 at top, tighter bounds
-                ax.set_ylim(max_rank_used + 0.3, 0.7)  # Tighter bounds
+                # Set y limits with rank 1 at top, more padding to prevent cutoff
+                ax.set_ylim(max_rank_used + 0.8, 0.3)  # More padding for tokens
 
                 # Set reasonable number of y-ticks
                 n_ticks = min(max_rank_used, max_vocab_display)
                 ax.set_yticks(range(1, n_ticks + 1))
                 ax.set_yticklabels(range(1, n_ticks + 1))
             else:
-                ax.set_ylim(max_vocab_display + 0.3, 0.7)
+                ax.set_ylim(max_vocab_display + 0.8, 0.3)
                 ax.set_yticks(range(1, min(max_vocab_display + 1, 16)))
         else:
-            ax.set_ylim(max_vocab_display + 0.3, 0.7)
+            ax.set_ylim(max_vocab_display + 0.8, 0.3)
             ax.set_yticks(range(1, min(max_vocab_display + 1, 16)))
 
-        # Customize plot appearance with larger fonts
-        ax.set_xlabel('Output position', fontsize=14)
-        ax.set_ylabel('Token rank (by frequency)', fontsize=14)
-        ax.tick_params(axis='both', labelsize=12)  # Larger tick labels
+        # Customize plot appearance with much larger fonts
+        ax.set_xlabel('Output position', fontsize=18, fontweight='normal')
+        ax.set_ylabel('Token rank (by frequency)', fontsize=18, fontweight='normal')
+        ax.tick_params(axis='both', labelsize=14)  # Larger tick labels
 
         # No title per requirements
 
@@ -828,9 +828,9 @@ def _add_dual_probability_legend_improved(
         else:
             ax_pos = None
 
-        # Create new smaller axis for colorbar, closer to main plot
-        # Position at bottom 1/3, narrow, closer to main plot
-        small_cbar_ax = fig.add_axes([0.86, 0.15, 0.008, 0.25])  # [left, bottom, width, height]
+        # Create new smaller axis for colorbar with small gap from main plot
+        # Position at bottom 1/3, narrow, with appropriate spacing
+        small_cbar_ax = fig.add_axes([0.90, 0.15, 0.008, 0.25])  # [left, bottom, width, height]
 
         # Create colorbar in the smaller axis
         cbar = ColorbarBase(
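
To make the colorbar repositioning concrete, here is a standalone sketch of the placement pattern used above: a thin axis added in figure coordinates at x=0.90 so it sits with a small gap to the right of a main plot that ends near x=0.88. The colormap, normalization, and the subplots_adjust value are assumptions for illustration; in the repository these come from visualize_bumplot and _add_dual_probability_legend_improved.

import matplotlib.pyplot as plt
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Normalize

fig, ax = plt.subplots(figsize=(16, 10))
fig.subplots_adjust(right=0.88)  # assumed: main axes stop short of the colorbar

# [left, bottom, width, height] in figure fraction: a narrow bar in the lower
# third of the figure, leaving a small gap between it and the main plot
small_cbar_ax = fig.add_axes([0.90, 0.15, 0.008, 0.25])
cbar = ColorbarBase(
    small_cbar_ax,
    cmap=plt.get_cmap('RdYlGn'),
    norm=Normalize(vmin=0.0, vmax=1.0),  # placeholder probability range
    orientation='vertical',
)
fig.savefig('colorbar_layout.png', dpi=150)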

code/test_refinements.png

1.34 MB

code/test_refinements.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+"""Test final refinements to bumplot."""
+
+import sys
+sys.path.insert(0, '/Users/jmanning/quantum-conversations/code')
+
+from quantum_conversations import ParticleFilter, TokenSequenceVisualizer, ModelManager
+import matplotlib.pyplot as plt
+
+model_manager = ModelManager()
+
+print("Testing refinements...")
+
+# Test with medium complexity
+pf = ParticleFilter(
+    model_name="EleutherAI/pythia-70m",
+    n_particles=20,
+    temperature=0.7,
+    device="cpu",
+    model_manager=model_manager,
+    seed=42
+)
+
+particles = pf.generate("The future is", max_new_tokens=10)
+viz = TokenSequenceVisualizer(tokenizer=pf.tokenizer)
+
+fig = viz.visualize_bumplot(
+    particles,
+    output_path="test_refinements.png",
+    max_vocab_display=8,
+    show_tokens=True,
+    colormap='RdYlGn',
+    figsize=(16, 10)
+)
+
+plt.close(fig)
+
+print("\nRefinements applied:")
+print("✓ Axis labels: 18pt font")
+print("✓ Tick labels: 14pt font")
+print("✓ Y-axis range: More padding (0.3 to max+0.8)")
+print("✓ Colorbar: x=0.90 (small gap from figure)")
+print("✓ Token length: Max 15 chars (13 + ..)")
+print("✓ Adaptive font: 10pt for long tokens, 12pt standard")
+print("✓ Curve alpha: 0.2 (more transparent)")
+print("\nSaved to: test_refinements.png")
