55from scipy .stats import chi2 , norm
66
77from causallearn .utils .KCI .KCI import KCI_CInd , KCI_UInd
8+ from causallearn .utils .FastKCI .FastKCI import FastKCI_CInd , FastKCI_UInd
9+ from causallearn .utils .RCIT .RCIT import RCIT as RCIT_CInd
10+ from causallearn .utils .RCIT .RCIT import RIT as RCIT_UInd
811from causallearn .utils .PCUtils import Helper
912
1013CONST_BINCOUNT_UNIQUE_THRESHOLD = 1e5
1316mv_fisherz = "mv_fisherz"
1417mc_fisherz = "mc_fisherz"
1518kci = "kci"
19+ rcit = "rcit"
20+ fastkci = "fastkci"
1621chisq = "chisq"
1722gsq = "gsq"
1823d_separation = "d_separation"
@@ -23,15 +28,19 @@ def CIT(data, method='fisherz', **kwargs):
2328 Parameters
2429 ----------
2530 data: numpy.ndarray of shape (n_samples, n_features)
26- method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "chisq", "gsq"]
27- kwargs: placeholder for future arguments, or for KCI specific arguments now
31+ method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "rcit", "fastkci", " chisq", "gsq"]
32+ kwargs: placeholder for future arguments, or for KCI, FastKCI or RCIT specific arguments now
2833 TODO: utimately kwargs should be replaced by explicit named parameters.
2934 check https://github.com/cmu-phil/causal-learn/pull/62#discussion_r927239028
3035 '''
3136 if method == fisherz :
3237 return FisherZ (data , ** kwargs )
3338 elif method == kci :
3439 return KCI (data , ** kwargs )
40+ elif method == fastkci :
41+ return FastKCI (data , ** kwargs )
42+ elif method == rcit :
43+ return RCIT (data , ** kwargs )
3544 elif method in [chisq , gsq ]:
3645 return Chisq_or_Gsq (data , method_name = method , ** kwargs )
3746 elif method == mv_fisherz :
@@ -43,6 +52,7 @@ def CIT(data, method='fisherz', **kwargs):
4352 else :
4453 raise ValueError ("Unknown method: {}" .format (method ))
4554
55+
4656class CIT_Base (object ):
4757 # Base class for CIT, contains basic operations for input check and caching, etc.
4858 def __init__ (self , data , cache_path = None , ** kwargs ):
@@ -193,6 +203,50 @@ def __call__(self, X, Y, condition_set=None):
193203 self .pvalue_cache [cache_key ] = p
194204 return p
195205
206+ class FastKCI (CIT_Base ):
207+ def __init__ (self , data , ** kwargs ):
208+ super ().__init__ (data , ** kwargs )
209+ kci_ui_kwargs = {k : v for k , v in kwargs .items () if k in
210+ ['K' , 'J' , 'alpha' ]}
211+ kci_ci_kwargs = {k : v for k , v in kwargs .items () if k in
212+ ['K' , 'J' , 'alpha' , 'use_gp' ]}
213+ self .check_cache_method_consistent (
214+ 'kci' , hashlib .md5 (json .dumps (kci_ci_kwargs , sort_keys = True ).encode ('utf-8' )).hexdigest ())
215+ self .assert_input_data_is_valid ()
216+ self .kci_ui = FastKCI_UInd (** kci_ui_kwargs )
217+ self .kci_ci = FastKCI_CInd (** kci_ci_kwargs )
218+
219+ def __call__ (self , X , Y , condition_set = None ):
220+ # Kernel-based conditional independence test.
221+ Xs , Ys , condition_set , cache_key = self .get_formatted_XYZ_and_cachekey (X , Y , condition_set )
222+ if cache_key in self .pvalue_cache : return self .pvalue_cache [cache_key ]
223+ p = self .kci_ui .compute_pvalue (self .data [:, Xs ], self .data [:, Ys ])[0 ] if len (condition_set ) == 0 else \
224+ self .kci_ci .compute_pvalue (self .data [:, Xs ], self .data [:, Ys ], self .data [:, condition_set ])[0 ]
225+ self .pvalue_cache [cache_key ] = p
226+ return p
227+
228+ class RCIT (CIT_Base ):
229+ def __init__ (self , data , ** kwargs ):
230+ super ().__init__ (data , ** kwargs )
231+ rit_kwargs = {k : v for k , v in kwargs .items () if k in
232+ ['approx' ]}
233+ rcit_kwargs = {k : v for k , v in kwargs .items () if k in
234+ ['approx' , 'num_f' , 'num_f2' , 'rcit' ]}
235+ self .check_cache_method_consistent (
236+ 'kci' , hashlib .md5 (json .dumps (rcit_kwargs , sort_keys = True ).encode ('utf-8' )).hexdigest ())
237+ self .assert_input_data_is_valid ()
238+ self .rit = RCIT_UInd (** rit_kwargs )
239+ self .rcit = RCIT_CInd (** rcit_kwargs )
240+
241+ def __call__ (self , X , Y , condition_set = None ):
242+ # Kernel-based conditional independence test.
243+ Xs , Ys , condition_set , cache_key = self .get_formatted_XYZ_and_cachekey (X , Y , condition_set )
244+ if cache_key in self .pvalue_cache : return self .pvalue_cache [cache_key ]
245+ p = self .rit .compute_pvalue (self .data [:, Xs ], self .data [:, Ys ])[0 ] if len (condition_set ) == 0 else \
246+ self .rcit .compute_pvalue (self .data [:, Xs ], self .data [:, Ys ], self .data [:, condition_set ])[0 ]
247+ self .pvalue_cache [cache_key ] = p
248+ return p
249+
196250class Chisq_or_Gsq (CIT_Base ):
197251 def __init__ (self , data , method_name , ** kwargs ):
198252 def _unique (column ):
0 commit comments