|
2 | 2 | Visualization tools for token sequences and particle paths. |
3 | 3 |
|
4 | 4 | Provides multiple visualization types: |
5 | | -- Sankey-like diagrams for token generation paths |
| 5 | +- Bumplot diagrams for token generation paths |
6 | 6 | - Probability heatmaps |
7 | 7 | - Divergence plots |
8 | 8 | - Token distribution visualizations |
@@ -80,115 +80,19 @@ def visualize( |
80 | 80 | particles: List[Particle], |
81 | 81 | prompt: str, |
82 | 82 | output_path: Optional[str] = None, |
83 | | - title: Optional[str] = None, |
84 | | - highlight_most_probable: bool = True |
| 83 | + **kwargs |
85 | 84 | ) -> plt.Figure: |
86 | 85 | """ |
87 | | - Create a Sankey-like visualization of token sequences. |
88 | | - |
89 | | - Args: |
90 | | - particles: List of particles from the filter |
91 | | - prompt: Initial prompt |
92 | | - output_path: Path to save the figure |
93 | | - title: Plot title |
94 | | - highlight_most_probable: Whether to highlight the most probable path |
95 | | - |
96 | | - Returns: |
97 | | - Matplotlib figure |
| 86 | + Deprecated: Use visualize_bumplot instead. |
| 87 | + This method now redirects to visualize_bumplot for backwards compatibility. |
98 | 88 | """ |
99 | | - fig, ax = plt.subplots(figsize=self.figsize) |
100 | | - |
101 | | - if title is None: |
102 | | - title = f"Token Generation Paths: \"{prompt}\"" |
103 | | - |
104 | | - # Find the maximum sequence length |
105 | | - max_length = min( |
106 | | - max(len(p.tokens) for p in particles), |
107 | | - self.max_tokens_display |
| 89 | + return self.visualize_bumplot( |
| 90 | + particles=particles, |
| 91 | + prompt=prompt, |
| 92 | + output_path=output_path, |
| 93 | + **kwargs |
108 | 94 | ) |
109 | | - |
110 | | - # Find the most probable particle |
111 | | - most_probable_idx = None |
112 | | - if highlight_most_probable and particles: |
113 | | - log_probs = [p.log_prob for p in particles] |
114 | | - most_probable_idx = np.argmax(log_probs) |
115 | | - |
116 | | - # Draw paths for each particle |
117 | | - for i, particle in enumerate(particles): |
118 | | - # Create path coordinates |
119 | | - x_coords = [] |
120 | | - y_coords = [] |
121 | | - |
122 | | - for t in range(min(len(particle.tokens), max_length)): |
123 | | - x_coords.append(t) |
124 | | - # Map token to y-position (spread tokens vertically) |
125 | | - # Use token ID modulo to distribute tokens |
126 | | - y_pos = (particle.tokens[t] % 1000) / 10.0 - 50.0 |
127 | | - y_coords.append(y_pos) |
128 | | - |
129 | | - # Determine if this is the most probable path |
130 | | - is_most_probable = (i == most_probable_idx) |
131 | | - |
132 | | - # Set line properties |
133 | | - if is_most_probable: |
134 | | - color = 'red' |
135 | | - alpha = 1.0 |
136 | | - linewidth = 2.0 |
137 | | - zorder = 1000 # Draw on top |
138 | | - else: |
139 | | - color = 'black' |
140 | | - alpha = self.alpha |
141 | | - linewidth = self.line_width |
142 | | - zorder = 1 |
143 | | - |
144 | | - # Draw the path |
145 | | - if len(x_coords) > 1: |
146 | | - ax.plot(x_coords, y_coords, |
147 | | - color=color, alpha=alpha, linewidth=linewidth, |
148 | | - zorder=zorder) |
149 | | - |
150 | | - # Add token labels for the most probable sequence |
151 | | - if highlight_most_probable and most_probable_idx is not None: |
152 | | - most_probable_particle = particles[most_probable_idx] |
153 | | - |
154 | | - # Add text below the plot showing the most probable sequence |
155 | | - text = self.tokenizer.decode(most_probable_particle.tokens[:max_length]) |
156 | | - ax.text(0.5, -0.15, f"Most probable sequence: {text}", |
157 | | - transform=ax.transAxes, |
158 | | - ha='center', va='top', |
159 | | - fontsize=10, |
160 | | - bbox=dict(boxstyle="round,pad=0.5", facecolor='white', alpha=0.8)) |
161 | | - |
162 | | - # Set axis properties |
163 | | - ax.set_xlim(-0.5, max_length - 0.5) |
164 | | - ax.set_ylim(-60, 60) |
165 | | - ax.set_xlabel("Token Position", fontsize=12) |
166 | | - ax.set_ylabel("Token Space", fontsize=12) |
167 | | - ax.set_title(title, fontsize=14, fontweight='bold') |
168 | | - |
169 | | - # Remove y-axis ticks (too many tokens to label) |
170 | | - ax.set_yticks([]) |
171 | | - |
172 | | - # Add grid |
173 | | - ax.grid(True, axis='x', alpha=0.3) |
174 | | - |
175 | | - # Add legend |
176 | | - ax.text( |
177 | | - 0.02, 0.98, |
178 | | - f"Particles: {len(particles)}", |
179 | | - transform=ax.transAxes, |
180 | | - fontsize=10, |
181 | | - verticalalignment='top', |
182 | | - bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5) |
183 | | - ) |
184 | | - |
185 | | - plt.tight_layout() |
186 | | - |
187 | | - if output_path: |
188 | | - plt.savefig(output_path, dpi=self.save_dpi, bbox_inches='tight') |
189 | 95 |
|
190 | | - return fig |
191 | | - |
192 | 96 | def visualize_probability_heatmap( |
193 | 97 | self, |
194 | 98 | particles: List[Particle], |
|
0 commit comments