From ec493405d47991c97a62df6f43f132e9d98c81bd Mon Sep 17 00:00:00 2001 From: Yujia Zheng Date: Fri, 10 Jan 2025 20:58:52 -0500 Subject: [PATCH 1/2] Update LocalScoreFunctionClass.py --- causallearn/score/LocalScoreFunctionClass.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/causallearn/score/LocalScoreFunctionClass.py b/causallearn/score/LocalScoreFunctionClass.py index 8171d515..64b2081f 100644 --- a/causallearn/score/LocalScoreFunctionClass.py +++ b/causallearn/score/LocalScoreFunctionClass.py @@ -29,7 +29,7 @@ def __init__( self.parameters = parameters self.score_cache = {} - if self.local_score_fun == local_score_BIC_from_cov: + if self.local_score_fun.__name__ == 'local_score_BIC_from_cov': self.cov = np.cov(self.data.T) self.n = self.data.shape[0] @@ -40,7 +40,7 @@ def score(self, i: int, PAi: List[int]) -> float: hash_key = tuple(sorted(PAi)) if not self.score_cache[i].__contains__(hash_key): - if self.local_score_fun == local_score_BIC_from_cov: + if self.local_score_fun.__name__ == 'local_score_BIC_from_cov': self.score_cache[i][hash_key] = self.local_score_fun((self.cov, self.n), i, PAi, self.parameters) else: self.score_cache[i][hash_key] = self.local_score_fun(self.data, i, PAi, self.parameters) @@ -48,7 +48,7 @@ def score(self, i: int, PAi: List[int]) -> float: return self.score_cache[i][hash_key] def score_nocache(self, i: int, PAi: List[int]) -> float: - if self.local_score_fun == local_score_BIC_from_cov: + if self.local_score_fun.__name__ == 'local_score_BIC_from_cov': return self.local_score_fun((self.cov, self.n), i, PAi, self.parameters) else: - return self.local_score_fun(self.data, i, PAi, self.parameters) \ No newline at end of file + return self.local_score_fun(self.data, i, PAi, self.parameters) From bc9b002bd3a3de731826908318f21f803ccc4165 Mon Sep 17 00:00:00 2001 From: Yujia Zheng Date: Fri, 10 Jan 2025 20:59:32 -0500 Subject: [PATCH 2/2] Update BOSS.py --- causallearn/search/PermutationBased/BOSS.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/causallearn/search/PermutationBased/BOSS.py b/causallearn/search/PermutationBased/BOSS.py index cf18d43f..85551b8e 100644 --- a/causallearn/search/PermutationBased/BOSS.py +++ b/causallearn/search/PermutationBased/BOSS.py @@ -23,7 +23,7 @@ def boss( X: np.ndarray, - score_func: str = "local_score_BIC", + score_func: str = "local_score_BIC_from_cov", parameters: Optional[Dict[str, Any]] = None, verbose: Optional[bool] = True, node_names: Optional[List[str]] = None,