Skip to content

Commit a005c42

Browse files
committed
Remove Sankey diagram and fix alpha transparency
Two major improvements: 1. REMOVED SANKEY DIAGRAM: - Deleted entire Sankey visualization method (was redundant) - visualize() now redirects to visualize_bumplot for compatibility - Cleaned up all orphaned Sankey code - Updated documentation to reference bumplot instead 2. FIXED ALPHA TRANSPARENCY: - Forced alpha using RGBA color specification - Now properly shows transparency at alpha=0.1 - Overlapping curves clearly visible through each other - Uses matplotlib.colors.to_rgba() then modifies alpha channel The bumplot is now the sole trajectory visualization method, with properly working transparency for dense particle paths.
1 parent 4368f7d commit a005c42

File tree

2 files changed

+18
-111
lines changed

2 files changed

+18
-111
lines changed

code/quantum_conversations/custom_bumplot.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,15 @@ def create_custom_bumplot(
9999
x_segment = np.concatenate([x_segment, x_overlap])
100100
y_segment = np.concatenate([y_segment, y_overlap])
101101

102-
# Plot this segment with explicit alpha
103-
# Force matplotlib to use alpha blending
104-
line, = ax.plot(x_segment, y_segment, color=color,
105-
linewidth=linewidth, solid_capstyle='round',
106-
solid_joinstyle='round', zorder=100)
107-
line.set_alpha(alpha) # Explicitly set alpha on line object
102+
# Plot this segment with forced transparency
103+
# Use RGBA color to ensure alpha is applied
104+
import matplotlib.colors as mcolors
105+
rgba_color = list(mcolors.to_rgba(color))
106+
rgba_color[3] = alpha # Force alpha channel
107+
108+
ax.plot(x_segment, y_segment, color=rgba_color,
109+
linewidth=linewidth, solid_capstyle='round',
110+
solid_joinstyle='round', zorder=100)
108111

109112
except Exception as e:
110113
# Fallback to linear segments

code/quantum_conversations/visualizer.py

Lines changed: 9 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Visualization tools for token sequences and particle paths.
33
44
Provides multiple visualization types:
5-
- Sankey-like diagrams for token generation paths
5+
- Bumplot diagrams for token generation paths
66
- Probability heatmaps
77
- Divergence plots
88
- Token distribution visualizations
@@ -80,115 +80,19 @@ def visualize(
8080
particles: List[Particle],
8181
prompt: str,
8282
output_path: Optional[str] = None,
83-
title: Optional[str] = None,
84-
highlight_most_probable: bool = True
83+
**kwargs
8584
) -> plt.Figure:
8685
"""
87-
Create a Sankey-like visualization of token sequences.
88-
89-
Args:
90-
particles: List of particles from the filter
91-
prompt: Initial prompt
92-
output_path: Path to save the figure
93-
title: Plot title
94-
highlight_most_probable: Whether to highlight the most probable path
95-
96-
Returns:
97-
Matplotlib figure
86+
Deprecated: Use visualize_bumplot instead.
87+
This method now redirects to visualize_bumplot for backwards compatibility.
9888
"""
99-
fig, ax = plt.subplots(figsize=self.figsize)
100-
101-
if title is None:
102-
title = f"Token Generation Paths: \"{prompt}\""
103-
104-
# Find the maximum sequence length
105-
max_length = min(
106-
max(len(p.tokens) for p in particles),
107-
self.max_tokens_display
89+
return self.visualize_bumplot(
90+
particles=particles,
91+
prompt=prompt,
92+
output_path=output_path,
93+
**kwargs
10894
)
109-
110-
# Find the most probable particle
111-
most_probable_idx = None
112-
if highlight_most_probable and particles:
113-
log_probs = [p.log_prob for p in particles]
114-
most_probable_idx = np.argmax(log_probs)
115-
116-
# Draw paths for each particle
117-
for i, particle in enumerate(particles):
118-
# Create path coordinates
119-
x_coords = []
120-
y_coords = []
121-
122-
for t in range(min(len(particle.tokens), max_length)):
123-
x_coords.append(t)
124-
# Map token to y-position (spread tokens vertically)
125-
# Use token ID modulo to distribute tokens
126-
y_pos = (particle.tokens[t] % 1000) / 10.0 - 50.0
127-
y_coords.append(y_pos)
128-
129-
# Determine if this is the most probable path
130-
is_most_probable = (i == most_probable_idx)
131-
132-
# Set line properties
133-
if is_most_probable:
134-
color = 'red'
135-
alpha = 1.0
136-
linewidth = 2.0
137-
zorder = 1000 # Draw on top
138-
else:
139-
color = 'black'
140-
alpha = self.alpha
141-
linewidth = self.line_width
142-
zorder = 1
143-
144-
# Draw the path
145-
if len(x_coords) > 1:
146-
ax.plot(x_coords, y_coords,
147-
color=color, alpha=alpha, linewidth=linewidth,
148-
zorder=zorder)
149-
150-
# Add token labels for the most probable sequence
151-
if highlight_most_probable and most_probable_idx is not None:
152-
most_probable_particle = particles[most_probable_idx]
153-
154-
# Add text below the plot showing the most probable sequence
155-
text = self.tokenizer.decode(most_probable_particle.tokens[:max_length])
156-
ax.text(0.5, -0.15, f"Most probable sequence: {text}",
157-
transform=ax.transAxes,
158-
ha='center', va='top',
159-
fontsize=10,
160-
bbox=dict(boxstyle="round,pad=0.5", facecolor='white', alpha=0.8))
161-
162-
# Set axis properties
163-
ax.set_xlim(-0.5, max_length - 0.5)
164-
ax.set_ylim(-60, 60)
165-
ax.set_xlabel("Token Position", fontsize=12)
166-
ax.set_ylabel("Token Space", fontsize=12)
167-
ax.set_title(title, fontsize=14, fontweight='bold')
168-
169-
# Remove y-axis ticks (too many tokens to label)
170-
ax.set_yticks([])
171-
172-
# Add grid
173-
ax.grid(True, axis='x', alpha=0.3)
174-
175-
# Add legend
176-
ax.text(
177-
0.02, 0.98,
178-
f"Particles: {len(particles)}",
179-
transform=ax.transAxes,
180-
fontsize=10,
181-
verticalalignment='top',
182-
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)
183-
)
184-
185-
plt.tight_layout()
186-
187-
if output_path:
188-
plt.savefig(output_path, dpi=self.save_dpi, bbox_inches='tight')
18995

190-
return fig
191-
19296
def visualize_probability_heatmap(
19397
self,
19498
particles: List[Particle],

0 commit comments

Comments
 (0)