@@ -29,7 +29,7 @@ def __init__(
2929 self .parameters = parameters
3030 self .score_cache = {}
3131
32- if self .local_score_fun == local_score_BIC_from_cov :
32+ if self .local_score_fun . __name__ == ' local_score_BIC_from_cov' :
3333 self .cov = np .cov (self .data .T )
3434 self .n = self .data .shape [0 ]
3535
@@ -40,15 +40,15 @@ def score(self, i: int, PAi: List[int]) -> float:
4040 hash_key = tuple (sorted (PAi ))
4141
4242 if not self .score_cache [i ].__contains__ (hash_key ):
43- if self .local_score_fun == local_score_BIC_from_cov :
43+ if self .local_score_fun . __name__ == ' local_score_BIC_from_cov' :
4444 self .score_cache [i ][hash_key ] = self .local_score_fun ((self .cov , self .n ), i , PAi , self .parameters )
4545 else :
4646 self .score_cache [i ][hash_key ] = self .local_score_fun (self .data , i , PAi , self .parameters )
4747
4848 return self .score_cache [i ][hash_key ]
4949
5050 def score_nocache (self , i : int , PAi : List [int ]) -> float :
51- if self .local_score_fun == local_score_BIC_from_cov :
51+ if self .local_score_fun . __name__ == ' local_score_BIC_from_cov' :
5252 return self .local_score_fun ((self .cov , self .n ), i , PAi , self .parameters )
5353 else :
54- return self .local_score_fun (self .data , i , PAi , self .parameters )
54+ return self .local_score_fun (self .data , i , PAi , self .parameters )
0 commit comments