Skip to content

Commit 8c03d10

Browse files
committed
Enhance bar_plot function to support additional parameters and color palettes for improved visualization
1 parent ff9bb5e commit 8c03d10

File tree

1 file changed

+88
-51
lines changed

1 file changed

+88
-51
lines changed

Scripts/utils/validation.py

Lines changed: 88 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,24 @@ def _scatter_plot(df: pd.DataFrame) -> str:
387387
return fig.to_html()
388388
return _scatter_plot
389389

390-
def bar_plot(x: str = 'id', y: Union[str, List[str]] = None) -> Callable[[pd.DataFrame], str]:
390+
def bar_plot(x: str = 'id',
391+
y: Union[str, List[str]] = None,
392+
stacked: bool = False,
393+
color_palette: str = 'Plotly') -> Callable[[pd.DataFrame], str]:
394+
"""
395+
Creates a bar plot visualization function.
396+
397+
Args:
398+
x (str): The column name to use for the x-axis. Defaults to 'id'.
399+
y (Union[str, List[str]]): The column name(s) to plot on the y-axis.
400+
If None, uses ['prediction', 'expected']. Defaults to None.
401+
stacked (bool): Whether to stack the bars or group them. Defaults to False.
402+
color_palette (str): The color palette to use. Options: 'Plotly', 'G10',
403+
'Dark24', 'Light24', 'Pastel', etc. Defaults to 'Plotly'.
404+
405+
Returns:
406+
Callable[[pd.DataFrame], str]: A function that takes a DataFrame and returns HTML.
407+
"""
391408
if y is None:
392409
y = ['prediction', 'expected']
393410
if isinstance(y, str):
@@ -396,64 +413,84 @@ def bar_plot(x: str = 'id', y: Union[str, List[str]] = None) -> Callable[[pd.Dat
396413
def _bar_plot(df: pd.DataFrame) -> str:
397414
try:
398415
import plotly.graph_objects as go
416+
import plotly.express as px
399417
except ImportError:
400418
return "Plotly is not installed. Please install it using 'pip install plotly'"
401419

420+
# Get the appropriate color palette
421+
try:
422+
color_seq = getattr(px.colors.qualitative, color_palette)
423+
except AttributeError:
424+
# Fall back to default Plotly colors if the palette doesn't exist
425+
color_seq = px.colors.qualitative.Plotly
426+
402427
fig = go.Figure()
403-
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'cyan', 'magenta']
404428

405429
for i, y_col in enumerate(y):
406-
fig.add_trace(go.Bar(name=y_col, x=df[x], y=df[y_col], marker_color=colors[i % len(colors)]))
430+
fig.add_trace(go.Bar(
431+
name=y_col,
432+
x=df[x],
433+
y=df[y_col],
434+
marker_color=color_seq[i % len(color_seq)]
435+
))
436+
437+
# Set barmode based on stacked parameter
438+
barmode = 'stack' if stacked else 'group'
439+
fig.update_layout(
440+
barmode=barmode,
441+
xaxis_title=x,
442+
yaxis_title='Value',
443+
legend_title_text='Metrics'
444+
)
407445

408-
fig.update_layout(barmode='group', xaxis_title=x, yaxis_title='Value')
409446
return fig.to_html()
410447

411448
return _bar_plot
412449

413-
# # Usage example
414-
# test_valid = Validation()
415-
416-
# # Create a group without default error terms
417-
# group = test_valid.create_group("test", add_default_error_terms=False)
418-
419-
# # Add sample predictions to the group
420-
# # use arbitrary "example1" and "example2" columns for grouping
421-
# group.add_item('s1', 1, 1, example1='a' , example2=1)
422-
# group.add_item('s2', 2, 2.5, example1='a' , example2=2)
423-
# group.add_item('s3', 3, 4, example1='b')
424-
# group.add_item('s4', 4, 6, example1='c' , example2=2)
425-
# group.add_item('s5', 5, 7, example1='b' , example2=1)
426-
427-
# # Add error terms
428-
# group.add_error_terms({'squared_error': squared_error})
429-
430-
# # Add aggregations to the group
431-
# # Mean absolute error for all items
432-
# group.add_aggregation('mae', mae)
433-
# # Maximum error for all items, grouped by test_val
434-
# group.add_aggregation('max', max_error, group_by='example2')
435-
436-
# # mean squared error for predictions >= 3, grouped by test1
437-
# group.add_aggregation("mse_error", mse, filter='prediction>=3', group_by='example1')
438-
# # Same as above but using a precalculated error term
439-
# group.add_aggregation('mse_error2', mean('squared_error'), filter='prediction>=3', group_by='example1')
440-
441-
# group2 = test_valid.create_group("test2")
442-
# # Add larget dataset of random points
443-
# for i in range(1, 1000):
444-
# group2.add_item(f'point{i}', i**1.02 + random.random()*200-150, i, even='even' if i % 2 == 0 else 'odd')
445-
# group2.add_aggregation('mse', mean('squared_error'), group_by='even')
446-
447-
# group.add_visualization('test bar plot', bar_plot(y=['prediction', 'expected', 'squared_error']))
448-
# group2.add_visualization('test scatter plot', scatter_plot(color='absolute_error'))
449-
450-
# # Run all aggregations and print the results
451-
# test_valid.to_html('test_validation.html')
452-
# # Save the validation object to a file
453-
# test_valid.save_to_file('test_validation.pklz')
454-
# # Load the validation object from a file
455-
# loaded_valid = Validation.load_from_file('test_validation.pklz')
456-
# loaded_valid.to_html('test_validation_loaded.html')
457-
# # Open the generated HTML file in the default web browser
458-
# import webbrowser
459-
# webbrowser.open('test_validation.html')
450+
# Usage example
451+
test_valid = Validation()
452+
453+
# Create a group without default error terms
454+
group = test_valid.create_group("test", add_default_error_terms=False)
455+
456+
# Add sample predictions to the group
457+
# use arbitrary "example1" and "example2" columns for grouping
458+
group.add_item('s1', 1, 1, example1='a' , example2=1)
459+
group.add_item('s2', 2, 2.5, example1='a' , example2=2)
460+
group.add_item('s3', 3, 4, example1='b')
461+
group.add_item('s4', 4, 6, example1='c' , example2=2)
462+
group.add_item('s5', 5, 7, example1='b' , example2=1)
463+
464+
# Add error terms
465+
group.add_error_terms({'squared_error': squared_error})
466+
467+
# Add aggregations to the group
468+
# Mean absolute error for all items
469+
group.add_aggregation('mae', mae)
470+
# Maximum error for all items, grouped by example2
471+
group.add_aggregation('max', max_error, group_by='example2')
472+
473+
# Mean squared error for predictions >= 3, grouped by example1
474+
group.add_aggregation("mse_error", mse, filter='prediction>=3', group_by='example1')
475+
# Same as above but using a precalculated error term
476+
group.add_aggregation('mse_error2', mean('squared_error'), filter='prediction>=3', group_by='example1')
477+
478+
group2 = test_valid.create_group("test2")
479+
# Add a larger dataset of random points
480+
for i in range(1, 1000):
481+
group2.add_item(f'point{i}', i**1.02 + random.random()*200-150, i, even='even' if i % 2 == 0 else 'odd')
482+
group2.add_aggregation('mse', mean('squared_error'), group_by='even')
483+
484+
group.add_visualization('test bar plot', bar_plot(y=['prediction', 'expected', 'squared_error']))
485+
group2.add_visualization('test scatter plot', scatter_plot(color='absolute_error'))
486+
487+
# Run all aggregations and write the results to an HTML report
488+
test_valid.to_html('test_validation.html')
489+
# Save the validation object to a file
490+
test_valid.save_to_file('test_validation.pklz')
491+
# Load the validation object from a file
492+
loaded_valid = Validation.load_from_file('test_validation.pklz')
493+
loaded_valid.to_html('test_validation_loaded.html')
494+
# Open the generated HTML file in the default web browser
495+
import webbrowser
496+
webbrowser.open('test_validation.html')

0 commit comments

Comments
 (0)