@@ -387,7 +387,24 @@ def _scatter_plot(df: pd.DataFrame) -> str:
387387 return fig .to_html ()
388388 return _scatter_plot
389389
390- def bar_plot (x : str = 'id' , y : Union [str , List [str ]] = None ) -> Callable [[pd .DataFrame ], str ]:
390+ def bar_plot (x : str = 'id' ,
391+ y : Union [str , List [str ]] = None ,
392+ stacked : bool = False ,
393+ color_palette : str = 'Plotly' ) -> Callable [[pd .DataFrame ], str ]:
394+ """
395+ Creates a bar plot visualization function.
396+
397+ Args:
398+ x (str): The column name to use for the x-axis. Defaults to 'id'.
399+ y (Union[str, List[str]]): The column name(s) to plot on the y-axis.
400+ If None, uses ['prediction', 'expected']. Defaults to None.
401+ stacked (bool): Whether to stack the bars or group them. Defaults to False.
402+ color_palette (str): The color palette to use. Options: 'Plotly', 'G10',
403+ 'Dark24', 'Light24', 'Pastel', etc. Defaults to 'Plotly'.
404+
405+ Returns:
406+ Callable[[pd.DataFrame], str]: A function that takes a DataFrame and returns HTML.
407+ """
391408 if y is None :
392409 y = ['prediction' , 'expected' ]
393410 if isinstance (y , str ):
@@ -396,64 +413,84 @@ def bar_plot(x: str = 'id', y: Union[str, List[str]] = None) -> Callable[[pd.Dat
def _bar_plot(df: pd.DataFrame) -> str:
    """Render the configured columns of ``df`` as a Plotly bar chart.

    The settings ``x``, ``y``, ``stacked`` and ``color_palette`` are read
    from the enclosing ``bar_plot`` call.

    Args:
        df: Data to plot. It is indexed with ``df[x]`` and ``df[y_col]`` for
            every entry of ``y``, so those columns must be present.

    Returns:
        The figure serialized by ``fig.to_html()``, or an installation hint
        string when Plotly cannot be imported.
    """
    try:
        import plotly.graph_objects as go
        import plotly.express as px
    except ImportError:
        return "Plotly is not installed. Please install it using 'pip install plotly'"

    # Resolve the palette by name. The qualitative module also exposes
    # attributes that are not palettes (e.g. the ``swatches`` helper and
    # module dunders), so an AttributeError check alone is not enough:
    # accept only a non-empty list/tuple and otherwise fall back to the
    # default Plotly colors.
    color_seq = None
    if isinstance(color_palette, str):
        color_seq = getattr(px.colors.qualitative, color_palette, None)
    if not isinstance(color_seq, (list, tuple)) or not color_seq:
        color_seq = px.colors.qualitative.Plotly

    fig = go.Figure()

    # One trace per requested column; colors wrap around when there are
    # more columns than palette entries.
    for i, y_col in enumerate(y):
        fig.add_trace(go.Bar(
            name=y_col,
            x=df[x],
            y=df[y_col],
            marker_color=color_seq[i % len(color_seq)],
        ))

    # Bars are either stacked on top of each other or grouped side by side.
    fig.update_layout(
        barmode='stack' if stacked else 'group',
        xaxis_title=x,
        yaxis_title='Value',
        legend_title_text='Metrics',
    )

    return fig.to_html()
410447
411448 return _bar_plot
412449
413- # # Usage example
414- # test_valid = Validation()
415-
416- # # Create a group without default error terms
417- # group = test_valid.create_group("test", add_default_error_terms=False)
418-
419- # # Add sample predictions to the group
420- # # use arbitrary "example1" and "example2" columns for grouping
421- # group.add_item('s1', 1, 1, example1='a' , example2=1)
422- # group.add_item('s2', 2, 2.5, example1='a' , example2=2)
423- # group.add_item('s3', 3, 4, example1='b')
424- # group.add_item('s4', 4, 6, example1='c' , example2=2)
425- # group.add_item('s5', 5, 7, example1='b' , example2=1)
426-
427- # # Add error terms
428- # group.add_error_terms({'squared_error': squared_error})
429-
430- # # Add aggregations to the group
431- # # Mean absolute error for all items
432- # group.add_aggregation('mae', mae)
433- # # Maximum error for all items, grouped by test_val
434- # group.add_aggregation('max', max_error, group_by='example2')
435-
436- # # mean squared error for predictions >= 3, grouped by test1
437- # group.add_aggregation("mse_error", mse, filter='prediction>=3', group_by='example1')
438- # # Same as above but using a precalculated error term
439- # group.add_aggregation('mse_error2', mean('squared_error'), filter='prediction>=3', group_by='example1')
440-
441- # group2 = test_valid.create_group("test2")
442- # # Add larget dataset of random points
443- # for i in range(1, 1000):
444- # group2.add_item(f'point{i}', i**1.02 + random.random()*200-150, i, even='even' if i % 2 == 0 else 'odd')
445- # group2.add_aggregation('mse', mean('squared_error'), group_by='even')
446-
447- # group.add_visualization('test bar plot', bar_plot(y=['prediction', 'expected', 'squared_error']))
448- # group2.add_visualization('test scatter plot', scatter_plot(color='absolute_error'))
449-
450- # # Run all aggregations and print the results
451- # test_valid.to_html('test_validation.html')
452- # # Save the validation object to a file
453- # test_valid.save_to_file('test_validation.pklz')
454- # # Load the validation object from a file
455- # loaded_valid = Validation.load_from_file('test_validation.pklz')
456- # loaded_valid.to_html('test_validation_loaded.html')
457- # # Open the generated HTML file in the default web browser
458- # import webbrowser
459- # webbrowser.open('test_validation.html')
# Usage example.
# Guarded so that importing this module does not build demo data, write
# files to the working directory, or open a web browser.
if __name__ == "__main__":
    import random
    import webbrowser

    test_valid = Validation()

    # Create a group without default error terms
    group = test_valid.create_group("test", add_default_error_terms=False)

    # Add sample predictions to the group
    # use arbitrary "example1" and "example2" columns for grouping
    group.add_item('s1', 1, 1, example1='a', example2=1)
    group.add_item('s2', 2, 2.5, example1='a', example2=2)
    group.add_item('s3', 3, 4, example1='b')
    group.add_item('s4', 4, 6, example1='c', example2=2)
    group.add_item('s5', 5, 7, example1='b', example2=1)

    # Add error terms
    group.add_error_terms({'squared_error': squared_error})

    # Add aggregations to the group
    # Mean absolute error for all items
    group.add_aggregation('mae', mae)
    # Maximum error for all items, grouped by example2
    group.add_aggregation('max', max_error, group_by='example2')

    # Mean squared error for predictions >= 3, grouped by example1
    group.add_aggregation("mse_error", mse, filter='prediction>=3', group_by='example1')
    # Same as above but using a precalculated error term
    group.add_aggregation('mse_error2', mean('squared_error'), filter='prediction>=3', group_by='example1')

    group2 = test_valid.create_group("test2")
    # Add a larger dataset of random points
    for i in range(1, 1000):
        group2.add_item(f'point{i}', i**1.02 + random.random()*200 - 150, i, even='even' if i % 2 == 0 else 'odd')
    group2.add_aggregation('mse', mean('squared_error'), group_by='even')

    group.add_visualization('test bar plot', bar_plot(y=['prediction', 'expected', 'squared_error']))
    group2.add_visualization('test scatter plot', scatter_plot(color='absolute_error'))

    # Render the validation report to an HTML file
    test_valid.to_html('test_validation.html')
    # Save the validation object to a file
    test_valid.save_to_file('test_validation.pklz')
    # Load the validation object back from that file
    loaded_valid = Validation.load_from_file('test_validation.pklz')
    loaded_valid.to_html('test_validation_loaded.html')
    # Open the generated HTML file in the default web browser
    webbrowser.open('test_validation.html')
0 commit comments