@@ -334,12 +334,9 @@ def plot_affiliation_shape(
334334 :param path: The file path to save the plot. If empty and save is True, a default path will be used.
335335 """
336336 try :
337- import pandas as pd
338337 import matplotlib .pyplot as plt
339338 except ImportError :
340- raise ImportError (
341- "pandas and matplotlib are required for plotting. Please install them."
342- )
339+ raise ImportError ("matplotlib is required for plotting. Please install it." )
343340
344341 all_affiliations = self .get_unique_term_affiliations ()
345342 if affiliation not in all_affiliations :
@@ -362,43 +359,55 @@ def plot_affiliation_shape(
362359
363360 predictor_names = affiliation .split (" & " )
364361
365- shape_df = pd .DataFrame (shape , columns = predictor_names + ["contribution" ])
366-
367362 is_main_effect : bool = len (predictor_indexes_used ) == 1
368363 is_two_way_interaction : bool = len (predictor_indexes_used ) == 2
369364
370365 if is_main_effect :
371366 fig = plt .figure ()
372- plt .plot (shape_df .iloc [:, 0 ], shape_df .iloc [:, 1 ])
373- plt .xlabel (shape_df .columns [0 ])
367+ # Sort by predictor value for a clean line plot
368+ sorted_indices = np .argsort (shape [:, 0 ])
369+ plt .plot (shape [sorted_indices , 0 ], shape [sorted_indices , 1 ])
370+ plt .xlabel (predictor_names [0 ])
374371 plt .ylabel ("Contribution to linear predictor" )
375- plt .title (f"Main effect of { shape_df . columns [0 ]} " )
372+ plt .title (f"Main effect of { predictor_names [0 ]} " )
376373 plt .grid (True )
377374 elif is_two_way_interaction :
378375 fig = plt .figure (figsize = (8 , 6 ))
379- pivot_table = shape_df .pivot_table (
380- index = shape_df .columns [0 ],
381- columns = shape_df .columns [1 ],
382- values = shape_df .columns [2 ],
383- aggfunc = "mean" ,
384- )
376+
377+ # Get unique coordinates and their inverse mapping
378+ y_unique , y_inv = np .unique (shape [:, 0 ], return_inverse = True )
379+ x_unique , x_inv = np .unique (shape [:, 1 ], return_inverse = True )
380+
381+ # Create grid for sums and counts
382+ grid_sums = np .zeros ((len (y_unique ), len (x_unique )))
383+ grid_counts = np .zeros ((len (y_unique ), len (x_unique )))
384+
385+ # Populate sums and counts to later calculate the mean
386+ np .add .at (grid_sums , (y_inv , x_inv ), shape [:, 2 ])
387+ np .add .at (grid_counts , (y_inv , x_inv ), 1 )
388+
389+ # Calculate mean, avoiding division by zero
390+ with np .errstate (divide = "ignore" , invalid = "ignore" ):
391+ pivot_table_values = np .true_divide (grid_sums , grid_counts )
392+ # Where there's no data, pivot_table_values will be nan, which is fine for imshow.
393+
385394 plt .imshow (
386- pivot_table . values ,
395+ pivot_table_values ,
387396 aspect = "auto" ,
388397 origin = "lower" ,
389398 extent = [
390- pivot_table . columns .min (),
391- pivot_table . columns .max (),
392- pivot_table . index .min (),
393- pivot_table . index .max (),
399+ x_unique .min (),
400+ x_unique .max (),
401+ y_unique .min (),
402+ y_unique .max (),
394403 ],
395404 cmap = "Blues_r" ,
396405 )
397406 plt .colorbar (label = "Contribution to the linear predictor" )
398- plt .xlabel (shape_df . columns [1 ])
399- plt .ylabel (shape_df . columns [0 ])
407+ plt .xlabel (predictor_names [1 ])
408+ plt .ylabel (predictor_names [0 ])
400409 plt .title (
401- f"Interaction between { shape_df . columns [0 ]} and { shape_df . columns [1 ]} "
410+ f"Interaction between { predictor_names [0 ]} and { predictor_names [1 ]} "
402411 )
403412 else :
404413 print (
@@ -407,9 +416,7 @@ def plot_affiliation_shape(
407416 return
408417
409418 if save :
410- save_path = (
411- path if path else f"shape_of_{ affiliation .replace (' & ' , '_' )} .png"
412- )
419+ save_path = path or f"shape_of_{ affiliation .replace (' & ' , '_' )} .png"
413420 plt .savefig (save_path )
414421
415422 if plot :
0 commit comments