33from matplotlib .colors import Normalize
44import numpy as np
55import plotly .express as px
6+ import plotly .graph_objects as go
67import pandas as pd
78from sklearn .cluster import KMeans
9+ from itertools import combinations
810
911
1012def hill_slopes (rule , transactions ):
@@ -554,4 +556,115 @@ def prepare_data(rules, metrics):
554556 plt .legend (title = "Order" )
555557 plt .grid (True )
556558 return plt
557-
559+
560+
561+ def sankey_diagram (rules , interestingness_measure , M = 4 ):
562+ """
563+ Visualize rules as a sankey diagram.
564+
565+ Args:
566+ rules (Rule): Association rule or rules to visualize.
567+ interestingness_measures (str): Interestingness measure Z = {supp, cons, lift},reflecting the quality of a particular connection.
568+ m (int): Maximum number of rules to be selected for visualization. Default: 4
569+
570+ Returns:
571+ Figure or plot.
572+ """
573+
574+
575+ def compute_similarity (rule1 , rule2 ):
576+ """Compute similarity between two rules."""
577+ ant_inter = len (set (str (rule1 .antecedent )) & set (str (rule2 .antecedent )))
578+ ant_union = len (set (str (rule1 .antecedent )) | set (str (rule2 .antecedent )))
579+ con_inter = len (set (str (rule1 .consequent )) & set (str (rule2 .consequent )))
580+ con_union = len (set (str (rule1 .consequent )) | set (str (rule2 .consequent )))
581+ return (ant_inter + con_inter ) / (ant_union + con_union )
582+
583+ def build_adjacency_matrix (rules ):
584+ size = len (rules )
585+ adjacency_matrix = np .zeros ((size , size ))
586+
587+ for i , j in combinations (range (size ), 2 ):
588+ similarity = compute_similarity (rules [i ], rules [j ])
589+ adjacency_matrix [i , j ] = similarity
590+ adjacency_matrix [j , i ] = similarity
591+
592+ return adjacency_matrix
593+
594+ def knapsack_selection (adj_matrix , rules , M ):
595+ fitness_scores = np .array ([rule .fitness for rule in rules ])
596+ N = len (rules )
597+ weights = np .ones (N )
598+ similarity_weight = 1.0
599+ fitness_weight = 0.5
600+ combined_profits = similarity_weight * np .sum (adj_matrix ) + fitness_weight * fitness_scores
601+
602+ selected = np .zeros (N , dtype = int )
603+
604+ # Initialize DP table
605+ dp = np .zeros ((N + 1 , M + 1 ))
606+ for i in range (1 , N + 1 ):
607+ for w in range (1 , M + 1 ):
608+ if weights [i - 1 ] <= w :
609+ dp [i , w ] = max (dp [i - 1 , w ], dp [i - 1 , w - 1 ] + combined_profits [i - 1 ])
610+ else :
611+ dp [i , w ] = dp [i - 1 , w ]
612+
613+ # Backtrack to find selected rules
614+ w = M
615+ for i in range (N , 0 , - 1 ):
616+ if dp [i , w ] != dp [i - 1 , w ]:
617+ selected [i - 1 ] = 1
618+ w -= 1
619+
620+ selected_rules = [rules [i ] for i in range (N ) if selected [i ]]
621+
622+ return selected_rules
623+
624+ def prepare_data (rules , M , interestingness_measure ):
625+ adj_matrix = build_adjacency_matrix (rules )
626+ selected_rules = knapsack_selection (adj_matrix , rules , M )
627+
628+ sources = []
629+ targets = []
630+ values = []
631+ labels = []
632+ node_indices = {}
633+
634+ for rule in selected_rules :
635+ for antecedent in rule .antecedent :
636+ if str (antecedent ) not in node_indices :
637+ node_indices [str (antecedent )] = len (labels )
638+ labels .append (str (antecedent ))
639+ sources .append (node_indices [str (antecedent )])
640+
641+ for consequent in rule .consequent :
642+ if str (consequent ) not in node_indices :
643+ node_indices [str (consequent )] = len (labels )
644+ labels .append (str (consequent ))
645+ targets .append (node_indices [str (consequent )])
646+
647+ measure_value = getattr (rule , interestingness_measure , rule .support ) #default support
648+ values .append (measure_value )
649+
650+ return labels , sources , targets , values
651+
652+ labels , sources , targets , values = prepare_data (rules , M , interestingness_measure )
653+
654+ fig = go .Figure (go .Sankey (
655+ node = dict (
656+ pad = 15 ,
657+ thickness = 20 ,
658+ line = dict (color = 'black' , width = 0.5 ),
659+ label = labels
660+ ),
661+ link = dict (
662+ source = sources ,
663+ target = targets ,
664+ value = values
665+ )
666+ ))
667+ fig .update_layout (title_text = f'Sankey Diagram of Association Rules ({ interestingness_measure } )' , font_size = 10 )
668+
669+ return fig
670+
0 commit comments