3636from .model import GraphModel
3737from ...utils .parallel import parallel_run
3838from ...utils .loss import notears_constr
39- from ...utils .torch import ChannelBatchNorm1d , MatrixSampler , Linear3D
39+ from ...utils .torch import (ChannelBatchNorm1d , MatrixSampler ,
40+ Linear3D , ParallelBatchNorm1d ,
41+ SimpleMatrixConnection )
4042from ...utils .Settings import SETTINGS
4143
4244
@@ -48,7 +50,7 @@ def permutation_matrix(self, skeleton, data_shape, max_dim):
4850
4951 for channel in range (self .nb_vars ):
5052 perm_matrix = skeleton [:, channel ] * th .eye (data_shape [1 ],data_shape [1 ])
51- skeleton_list = [i for i in th .unbind (perm_matrix , 1 ) if len ( th .nonzero ( i ) ) > 0 ]
53+ skeleton_list = [i for i in th .unbind (perm_matrix , 1 ) if th .count_nonzero ( i ) > 0 ]
5254 perm_matrix = th .stack (skeleton_list , 1 ) if len (skeleton_list )> 0 else th .zeros (data_shape [1 ], 1 )
5355 reshape_skeleton [channel , :, :perm_matrix .shape [1 ]] = perm_matrix
5456
@@ -190,12 +192,9 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu",
190192 sampletype = "sigmoidproba" ,
191193 dagstart = 0 , dagloss = False ,
192194 dagpenalization = 0.05 , dagpenalization_increase = 0.0 ,
193- categorical_threshold = 50 , use_filter = False ,
194- filter_threshold = 0.5 , dag_threshold = 0.5 ,
195+ categorical_threshold = 50 ,
195196 linear = False , numberHiddenLayersG = 2 , numberHiddenLayersD = 2 , idx = 0 ):
196197
197- d_str = "Epoch: {} -- Disc: {:.4f} -- Total: {:.4f} -- Gen: {:.4f} -- L1: {:.4f}"
198- # print("KLPenal:{}, fganLoss:{}".format(KLpenalization, fganLoss))
199198 list_nodes = list (in_data .columns )
200199 if is_mixed :
201200 onehotdata = []
@@ -218,8 +217,6 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu",
218217 if batch_size == - 1 :
219218 batch_size = data .shape [0 ]
220219
221- lambda2_sauv = lambda2
222-
223220 lambda1 = lambda1 / data .shape [0 ]
224221 lambda2 = lambda2 / data .shape [0 ]
225222
@@ -331,7 +328,6 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu",
331328 gen_loss = - th .mean (th .exp (disc_vars_g - 1 ), [0 , 2 ]).sum ()
332329
333330 filters = graph_sampler .get_proba ()
334-
335331 struc_loss = lambda1 * drawn_graph .sum ()
336332
337333 if linear :
@@ -342,7 +338,7 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu",
342338
343339
344340 elif functionalComplexity == "l2_norm" :
345- l2_reg = th .tensor ( 0. ).to (device )
341+ l2_reg = th .Tensor ([ 0. ] ).to (device )
346342 for param in sam .parameters ():
347343 l2_reg += th .norm (param )
348344
@@ -352,14 +348,6 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu",
352348
353349
354350 # Optional: prune edges and sam parameters before dag search
355- if epoch == int (train * dagstart ) and use_filter :
356- ones_tensor = th .ones (len (list_nodes ),len (list_nodes ))
357- zeros_tensor = th .zeros (len (list_nodes ),len (list_nodes ))
358- if not linear :
359- skeleton = th .where (filters .cpu () > filter_threshold , ones_tensor , zeros_tensor )
360- sam .apply_filter (skeleton , (batch_size , cols ), device )
361- graph_sampler .set_skeleton (skeleton .to (device ))
362- g_optimizer = th .optim .Adam (list (sam .parameters ()), lr = lr_gen )
363351
364352 if dagloss and epoch > train * dagstart :
365353 dag_constraint = notears_constr (filters * filters )
@@ -397,7 +385,7 @@ class SAM(GraphModel):
397385 independencies. The first version of SAM without DAG constraint is available
398386 as ``SAMv1``.
399387
400- **Data Type:** Continuous, Mixed ( Experimental)
388+ **Data Type:** Continuous, ( Mixed - Experimental)
401389
402390 **Assumptions:** The class of generative models is not restricted with a
403391 hard constraint, but with soft constraints parametrized with the ``lambda1``
@@ -409,33 +397,39 @@ class SAM(GraphModel):
409397 Args:
410398 lr (float): Learning rate of the generators
411399 dlr (float): Learning rate of the discriminator
400+ mixed_data (bool): Experimental -- Enable for mixed-type datasets
412401 lambda1 (float): L0 penalization coefficient on the causal filters
413- lambda2 (float): L0 penalization coefficient on the hidden units of the
402+ lambda2 (float): L2 penalization coefficient on the weights of the
414403 neural network
415404 nh (int): Number of hidden units in the generators' hidden layers
416405 (regularized with lambda2)
417- dnh (int): Number of hidden units in the discriminator's hidden layer
406+ dnh (int): Number of hidden units in the discriminator's hidden layers
418407 train_epochs (int): Number of training epochs
419408 test_epochs (int): Number of test epochs (saving and averaging
420409 the causal filters)
421- batch_size (int): Size of the batches to be fed to the SAM model.
422- Defaults to full-batch.
410+ batch_size (int): Size of the batches to be fed to the SAM model
411+ Defaults to full-batch
423412 losstype (str): type of the loss to be used (either 'fgan' (default),
424- 'gan' or 'mse').
425- hlayers (int): Defines the number of hidden layers in the discriminator.
426- dagloss (bool): Activate the DAG with No-TEARS constraint.
413+ 'gan' or 'mse')
414+ dagloss (bool): Activate the DAG with No-TEARS constraint
427415 dagstart (float): Controls when the DAG constraint is to be introduced
428416 in the training (float ranging from 0 to 1, 0 denotes the start of
429- the training and 1 the end).
430- dagpenalisation (float): Initial value of the DAG constraint.
417+ the training and 1 the end)
418+ dagpenalization (float): Initial value of the DAG constraint
431419 dagpenalization_increase (float): Increase incrementally at each epoch
432- the coefficient of the constraint.
433- linear (bool): If true, all generators are set to be linear generators.
434- nruns (int): Number of runs to be made for causal estimation.
435- Recommended: >=32 for optimal performance.
436- njobs (int): Numbers of jobs to be run in Parallel.
437- Recommended: 1 if no GPU available, 2*number of GPUs else.
438- gpus (int): Number of available GPUs for the algorithm.
420+ the coefficient of the constraint
421+ functional_complexity (str): Type of functional complexity penalization
422+ (choose between 'l2_norm' and 'n_hidden_units')
423+ hlayers (int): Defines the number of hidden layers in the generators
424+ dhlayers (int): Defines the number of hidden layers in the discriminator
425+ sampling_type (str): Type of sampling used in the structural gates of the
426+ model (choose between 'sigmoid', 'sigmoidproba' and 'gumble_proba')
427+ linear (bool): If true, all generators are set to be linear generators
428+ nruns (int): Number of runs to be made for causal estimation
429+ Recommended: >=32 for optimal performance
430+ njobs (int): Number of jobs to be run in parallel
431+ Recommended: 1 if no GPU available, 2*number of GPUs else
432+ gpus (int): Number of available GPUs for the algorithm
439433 verbose (bool): verbose mode
440434
441435 .. note::
@@ -465,14 +459,13 @@ class SAM(GraphModel):
465459 def __init__ (self , lr = 0.01 , dlr = 0.001 , mixed_data = False ,
466460 lambda1 = 10 , lambda2 = 0.001 ,
467461 nh = 20 , dnh = 200 ,
468- train_epochs = 3000 , test_epochs = 1000 , batchsize = - 1 ,
469- losstype = "fgan" , dagstart = 0.5 , dagloss = True ,
462+ train_epochs = 3000 , test_epochs = 1000 , batch_size = - 1 ,
463+ losstype = "fgan" , dagloss = True , dagstart = 0.5 ,
470464 dagpenalization = 0 ,
471- dagpenalization_increase = 0.01 , use_filter = False ,
472- filter_threshold = .5 ,
465+ dagpenalization_increase = 0.01 ,
473466 functional_complexity = 'l2_norm' , hlayers = 2 , dhlayers = 2 ,
474- sampling_type = 'sigmoidproba' , linear = False ,
475- njobs = None , gpus = None , verbose = None , nruns = 8 ):
467+ sampling_type = 'sigmoidproba' , linear = False , nruns = 8 ,
468+ njobs = None , gpus = None , verbose = None ):
476469
477470 """Init and parametrize the SAM model."""
478471 super (SAM , self ).__init__ ()
@@ -485,19 +478,17 @@ def __init__(self, lr=0.01, dlr=0.001, mixed_data=False,
485478 self .dnh = dnh
486479 self .train = train_epochs
487480 self .test = test_epochs
488- self .batchsize = batchsize
481+ self .batch_size = batch_size
489482 self .dagstart = dagstart
490483 self .dagloss = dagloss
491484 self .dagpenalization = dagpenalization
492485 self .dagpenalization_increase = dagpenalization_increase
493- self .use_filter = use_filter
494- self .filter_threshold = filter_threshold
495486 self .losstype = losstype
496487 self .functionalComplexity = functional_complexity
497488 self .sampletype = sampling_type
498489 self .linear = linear
499- self .numberHiddenLayersD = hlayers
500- self .numberHiddenLayersG = dhlayers
490+ self .numberHiddenLayersG = hlayers
491+ self .numberHiddenLayersD = dhlayers
501492 self .njobs = SETTINGS .get_default (njobs = njobs )
502493 self .gpus = SETTINGS .get_default (gpu = gpus )
503494 self .verbose = SETTINGS .get_default (verbose = verbose )
@@ -529,13 +520,11 @@ def predict(self, data, graph=None,
529520 lambda1 = self .lambda1 , lambda2 = self .lambda2 ,
530521 nh = self .nh , dnh = self .dnh ,
531522 train = self .train ,
532- test = self .test , batch_size = self .batchsize ,
523+ test = self .test , batch_size = self .batch_size ,
533524 dagstart = self .dagstart ,
534525 dagloss = self .dagloss ,
535526 dagpenalization = self .dagpenalization ,
536527 dagpenalization_increase = self .dagpenalization_increase ,
537- use_filter = self .use_filter ,
538- filter_threshold = self .filter_threshold ,
539528 losstype = self .losstype ,
540529 functionalComplexity = self .functionalComplexity ,
541530 sampletype = self .sampletype ,
@@ -552,13 +541,11 @@ def predict(self, data, graph=None,
552541 lambda1 = self .lambda1 , lambda2 = self .lambda2 ,
553542 nh = self .nh , dnh = self .dnh ,
554543 train = self .train ,
555- test = self .test , batch_size = self .batchsize ,
544+ test = self .test , batch_size = self .batch_size ,
556545 dagstart = self .dagstart ,
557546 dagloss = self .dagloss ,
558547 dagpenalization = self .dagpenalization ,
559548 dagpenalization_increase = self .dagpenalization_increase ,
560- use_filter = self .use_filter ,
561- filter_threshold = self .filter_threshold ,
562549 losstype = self .losstype ,
563550 functionalComplexity = self .functionalComplexity ,
564551 sampletype = self .sampletype ,
0 commit comments