Skip to content

Commit 4bf081e

Browse files
bbchoclaude
andcommitted
Fix continuous color gradient for geom_path and geom_line
Plotly's line.color only accepts a single value, not an array, so continuous color mapping on line geoms was silently broken. This fix uses segment-based rendering with Scattergl (WebGL) for performance, drawing each line segment with its own interpolated color. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent cd61698 commit 4bf081e

File tree

2 files changed

+182
-3
lines changed

2 files changed

+182
-3
lines changed

ggplotly/scales/scale_color_gradient.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,17 @@ def __init__(self, low="#132B43", high="#56B1F7", name=None, limits=None,
6161

6262
def apply(self, fig):
6363
"""
64-
Apply the color gradient to markers in the figure.
64+
Apply the color gradient to markers and line segments in the figure.
6565
6666
Parameters:
6767
fig (Figure): Plotly figure object.
6868
"""
69+
new_colorscale = [[0, self.low], [1, self.high]]
70+
6971
for trace in fig.data:
72+
# Handle marker-based traces (scatter points, etc.)
7073
if hasattr(trace, 'marker') and trace.marker is not None:
71-
trace.marker.colorscale = [[0, self.low], [1, self.high]]
74+
trace.marker.colorscale = new_colorscale
7275

7376
# Apply limits if specified
7477
if self.limits is not None:
@@ -89,3 +92,58 @@ def apply(self, fig):
8992
trace.marker.showscale = True
9093
else:
9194
trace.marker.showscale = False
95+
96+
# Handle line gradient segments (created by ContinuousColorTraceBuilder)
97+
if hasattr(trace, 'meta') and trace.meta:
98+
meta = trace.meta
99+
if isinstance(meta, dict) and meta.get('_ggplotly_line_gradient'):
100+
t_norm = meta.get('_color_norm', 0)
101+
new_color = self._interpolate_color(new_colorscale, t_norm)
102+
trace.line.color = new_color
103+
104+
@staticmethod
105+
def _interpolate_color(colorscale, t):
106+
"""
107+
Interpolate between colorscale endpoints.
108+
109+
Parameters:
110+
colorscale: List of [position, color] pairs
111+
t: Normalized value between 0 and 1
112+
113+
Returns:
114+
str: Interpolated RGB color string
115+
"""
116+
t = max(0, min(1, t)) # Clamp to [0, 1]
117+
118+
low_color = colorscale[0][1]
119+
high_color = colorscale[1][1]
120+
121+
# Parse color to RGB (handles hex and named colors)
122+
def color_to_rgb(color):
123+
if color.startswith('#'):
124+
color = color.lstrip('#')
125+
return tuple(int(color[i:i + 2], 16) for i in (0, 2, 4))
126+
elif color.startswith('rgb'):
127+
# Parse rgb(r, g, b) format
128+
import re
129+
match = re.match(r'rgb\((\d+),\s*(\d+),\s*(\d+)\)', color)
130+
if match:
131+
return tuple(int(x) for x in match.groups())
132+
# Fallback for named colors - approximate mapping
133+
named_colors = {
134+
'blue': (0, 0, 255), 'red': (255, 0, 0), 'green': (0, 128, 0),
135+
'white': (255, 255, 255), 'black': (0, 0, 0),
136+
'yellow': (255, 255, 0), 'orange': (255, 165, 0),
137+
'purple': (128, 0, 128), 'cyan': (0, 255, 255),
138+
}
139+
return named_colors.get(color.lower(), (128, 128, 128))
140+
141+
low_rgb = color_to_rgb(low_color)
142+
high_rgb = color_to_rgb(high_color)
143+
144+
# Linear interpolation
145+
r = int(low_rgb[0] + t * (high_rgb[0] - low_rgb[0]))
146+
g = int(low_rgb[1] + t * (high_rgb[1] - low_rgb[1]))
147+
b = int(low_rgb[2] + t * (high_rgb[2] - low_rgb[2]))
148+
149+
return f'rgb({r}, {g}, {b})'

ggplotly/trace_builders.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,13 +415,30 @@ class ContinuousColorTraceBuilder(TraceBuilder):
415415
with a colorscale instead of separate traces per category. Plotly
416416
handles the color interpolation via marker.color and marker.colorscale.
417417
418+
For line-based traces (mode='lines'), we use a segment-based approach
419+
since Plotly's line.color only accepts single values, not arrays.
420+
418421
Example:
419422
ggplot(df, aes(x='x', y='y', color='temperature')) + geom_point()
420423
# Single trace with colorscale from low to high temperature
421424
"""
422425

426+
# Default Viridis colorscale endpoints
427+
DEFAULT_COLORSCALE = [[0, '#440154'], [1, '#fde725']]
428+
423429
def build(self, apply_color_targets_fn):
424-
"""Build a single trace with colorscale for continuous color."""
430+
"""Build trace(s) with colorscale for continuous color."""
431+
# Check if this is a line-based trace
432+
# For lines, we need segment-based rendering since line.color
433+
# only accepts a single value, not an array
434+
if self.payload.get('mode') == 'lines':
435+
return self._build_line_gradient()
436+
437+
# Original marker-based approach for scatter, bar, etc.
438+
return self._build_marker_gradient()
439+
440+
def _build_marker_gradient(self):
441+
"""Build a single trace with marker colorscale (original approach)."""
425442
style_props = self.style_props
426443

427444
# Get the numeric color values
@@ -463,6 +480,110 @@ def build(self, apply_color_targets_fn):
463480
col=self.col,
464481
)
465482

483+
def _build_line_gradient(self):
484+
"""
485+
Build gradient line using individual colored segments.
486+
487+
Since Plotly's line.color only accepts a single value (not an array),
488+
we draw each segment as a separate Scattergl trace with its own color.
489+
Uses WebGL for efficient rendering of many traces.
490+
"""
491+
import plotly.graph_objects as go
492+
493+
style_props = self.style_props
494+
495+
# Get the numeric color values
496+
if style_props.get('color_is_continuous'):
497+
color_values = style_props['color_series']
498+
else:
499+
color_values = style_props['fill_series']
500+
501+
# Get line width from style_props or params
502+
line_width = style_props.get('size', 2)
503+
if line_width is None:
504+
line_width = self.params.get('size', 2)
505+
506+
# Extract arrays
507+
x_vals = self.x.values if hasattr(self.x, 'values') else list(self.x)
508+
y_vals = self.y.values if hasattr(self.y, 'values') else list(self.y)
509+
c_vals = color_values.values if hasattr(color_values, 'values') else list(color_values)
510+
511+
vmin, vmax = min(c_vals), max(c_vals)
512+
colorscale = self.DEFAULT_COLORSCALE
513+
514+
# Draw each segment with interpolated color
515+
for i in range(len(x_vals) - 1):
516+
# Normalize color value at midpoint of segment
517+
t_norm = ((c_vals[i] + c_vals[i + 1]) / 2 - vmin) / (vmax - vmin) if vmax != vmin else 0
518+
color = self._interpolate_color(colorscale, t_norm)
519+
520+
self.fig.add_trace(
521+
go.Scattergl( # WebGL for performance with many traces
522+
x=[x_vals[i], x_vals[i + 1]],
523+
y=[y_vals[i], y_vals[i + 1]],
524+
mode='lines',
525+
line=dict(color=color, width=line_width),
526+
opacity=self.alpha,
527+
showlegend=False,
528+
hoverinfo='skip',
529+
# Tag for scale_color_gradient to update colors
530+
meta={'_ggplotly_line_gradient': True, '_color_norm': t_norm}
531+
),
532+
row=self.row,
533+
col=self.col,
534+
)
535+
536+
# Add invisible trace for colorbar
537+
self.fig.add_trace(
538+
go.Scatter(
539+
x=[None],
540+
y=[None],
541+
mode='markers',
542+
marker=dict(
543+
color=[vmin, vmax],
544+
colorscale=colorscale,
545+
showscale=True,
546+
colorbar=dict(title=self.mapping.get('color', ''))
547+
),
548+
showlegend=False,
549+
hoverinfo='skip'
550+
),
551+
row=self.row,
552+
col=self.col,
553+
)
554+
555+
@staticmethod
556+
def _interpolate_color(colorscale, t):
557+
"""
558+
Interpolate between colorscale endpoints.
559+
560+
Parameters:
561+
colorscale: List of [position, color] pairs (e.g., [[0, '#440154'], [1, '#fde725']])
562+
t: Normalized value between 0 and 1
563+
564+
Returns:
565+
str: Interpolated RGB color string
566+
"""
567+
t = max(0, min(1, t)) # Clamp to [0, 1]
568+
569+
low_color = colorscale[0][1]
570+
high_color = colorscale[1][1]
571+
572+
# Parse hex colors to RGB
573+
def hex_to_rgb(hex_color):
574+
hex_color = hex_color.lstrip('#')
575+
return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
576+
577+
low_rgb = hex_to_rgb(low_color)
578+
high_rgb = hex_to_rgb(high_color)
579+
580+
# Linear interpolation
581+
r = int(low_rgb[0] + t * (high_rgb[0] - low_rgb[0]))
582+
g = int(low_rgb[1] + t * (high_rgb[1] - low_rgb[1]))
583+
b = int(low_rgb[2] + t * (high_rgb[2] - low_rgb[2]))
584+
585+
return f'rgb({r}, {g}, {b})'
586+
466587

467588
class SingleTraceBuilder(TraceBuilder):
468589
"""

0 commit comments

Comments
 (0)