
Commit c6f76a6

[UPDATE] SAM (2/2) + Doc + travis
1 parent 4a545a8 commit c6f76a6

5 files changed: +49 −62 lines changed

.travis.yml

Lines changed: 4 additions & 4 deletions

@@ -29,7 +29,7 @@ jobs:
     if: (branch != master) OR (fork = true)
     name: "Build and test"
     script:
-      - docker pull fentechai/nv-cdt-base:20.07
+      - docker pull fentechai/nv-cdt-base:21.01
       - docker build --build-arg python=3.6 --build-arg spy=36 -t fentechai/cdt-dev .
       - if [[ $TRAVIS_PULL_REQUEST == "false" ]]; then
          docker run -e CODECOV_TOKEN --rm fentechai/cdt-dev /bin/sh -c 'cd /CDT && pip3 uninstall cdt -y && python3 setup.py install develop --user && pytest --cov-report= --cov=./cdt && codecov --token $CODECOV_TOKEN';
@@ -62,7 +62,7 @@ jobs:
     name: "Build and push 3.6 image for testing"
     script:
       - echo "$DOCKER_PASSWORD" | docker login -u "$DOCKER_USERNAME" --password-stdin
-      - docker pull fentechai/cdt-base:20.07
+      - docker pull fentechai/cdt-base:21.01
       - docker build --build-arg python=3.6 --build-arg spy=36 -t fentechai/cdt-test .
       - docker push fentechai/cdt-test
   - stage: test
@@ -97,7 +97,7 @@ jobs:
       - git pull
       - VERSION_NEW=$(cat setup.py| grep version | cut -c 20- | rev | cut -c 3- | rev)
       - echo "$DOCKER_PASSWORD" | docker login -u "$DOCKER_USERNAME" --password-stdin
-      - docker pull fentechai/cdt-base:20.07
+      - docker pull fentechai/cdt-base:21.01
       - docker build --build-arg python=3.6 --build-arg spy=36 -t fentechai/cdt:$VERSION_NEW .
       - docker push fentechai/cdt:$VERSION_NEW
       - docker tag fentechai/cdt:$VERSION_NEW fentechai/cdt:latest
@@ -108,7 +108,7 @@ jobs:
       - git pull
       - VERSION_NEW=$(cat setup.py| grep version | cut -c 20- | rev | cut -c 3- | rev)
       - echo "$DOCKER_PASSWORD" | docker login -u "$DOCKER_USERNAME" --password-stdin
-      - docker pull fentechai/nv-cdt-base:20.07
+      - docker pull fentechai/nv-cdt-base:21.01
       - docker build --build-arg python=3.6 --build-arg spy=36 -f nv-Dockerfile -t fentechai/nv-cdt:$VERSION_NEW .
       - docker push fentechai/nv-cdt:$VERSION_NEW
       - docker tag fentechai/nv-cdt:$VERSION_NEW fentechai/nv-cdt:latest
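All four hunks pin the CI base images from the 20.07 tags to 21.01. Note the VERSION_NEW one-liner in the push stages: it recovers the package version from setup.py purely by character offsets (cut -c 20- drops the fixed-width prefix up to the version string, rev | cut -c 3- | rev drops the trailing quote and comma). A rough, offset-free Python equivalent as a sketch; the sample version='0.5.23', line is illustrative, not taken from the repository:

# Hypothetical sketch of what the VERSION_NEW shell pipeline extracts:
#   VERSION_NEW=$(cat setup.py| grep version | cut -c 20- | rev | cut -c 3- | rev)
# The character offsets are specific to this repo's setup.py layout;
# a regex does the same job without depending on column positions.
import re

def extract_version(setup_py_text):
    match = re.search(r"version\s*=\s*['\"]([^'\"]+)['\"]", setup_py_text)
    if match is None:
        raise ValueError("no version field found in setup.py")
    return match.group(1)

print(extract_version("setup(name='cdt',\n      version='0.5.23',\n)"))  # -> 0.5.23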

cdt/causality/graph/SAM.py

Lines changed: 39 additions & 52 deletions

@@ -36,7 +36,9 @@
 from .model import GraphModel
 from ...utils.parallel import parallel_run
 from ...utils.loss import notears_constr
-from ...utils.torch import ChannelBatchNorm1d, MatrixSampler, Linear3D
+from ...utils.torch import (ChannelBatchNorm1d, MatrixSampler,
+                            Linear3D, ParallelBatchNorm1d,
+                            SimpleMatrixConnection)
 from ...utils.Settings import SETTINGS
 
 
@@ -48,7 +50,7 @@ def permutation_matrix(self, skeleton, data_shape, max_dim):
 
         for channel in range(self.nb_vars):
             perm_matrix = skeleton[:, channel] * th.eye(data_shape[1],data_shape[1])
-            skeleton_list = [i for i in th.unbind(perm_matrix, 1) if len(th.nonzero(i)) > 0]
+            skeleton_list = [i for i in th.unbind(perm_matrix, 1) if th.count_nonzero(i) > 0]
             perm_matrix = th.stack(skeleton_list, 1) if len(skeleton_list)>0 else th.zeros(data_shape[1], 1)
             reshape_skeleton[channel, :, :perm_matrix.shape[1]] = perm_matrix
 
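The permutation_matrix change replaces len(th.nonzero(i)) > 0, which materializes an index tensor just to test emptiness, with a direct th.count_nonzero(i) > 0 reduction. torch.count_nonzero only exists from PyTorch 1.7 on, which is what the torch>1.7.0 pin in requirements.txt below enforces. A minimal equivalence check (tensor values are illustrative):

# Sketch: the old and new emptiness checks agree on 1-D tensors.
import torch as th

i = th.tensor([0.0, 2.5, 0.0])

old_check = len(th.nonzero(i)) > 0          # pre-commit: builds an index tensor first
new_check = bool(th.count_nonzero(i) > 0)   # post-commit: needs PyTorch >= 1.7

assert old_check == new_check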
@@ -190,12 +192,9 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu",
             sampletype="sigmoidproba",
             dagstart=0, dagloss=False,
             dagpenalization=0.05, dagpenalization_increase=0.0,
-            categorical_threshold=50, use_filter=False,
-            filter_threshold=0.5, dag_threshold=0.5,
+            categorical_threshold=50,
             linear=False, numberHiddenLayersG=2, numberHiddenLayersD=2, idx=0):
 
-    d_str = "Epoch: {} -- Disc: {:.4f} -- Total: {:.4f} -- Gen: {:.4f} -- L1: {:.4f}"
-    # print("KLPenal:{}, fganLoss:{}".format(KLpenalization, fganLoss))
     list_nodes = list(in_data.columns)
     if is_mixed:
         onehotdata = []
@@ -218,8 +217,6 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu",
     if batch_size == -1:
         batch_size = data.shape[0]
 
-    lambda2_sauv = lambda2
-
     lambda1 = lambda1/data.shape[0]
     lambda2 = lambda2/data.shape[0]
 
@@ -331,7 +328,6 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu",
             gen_loss = -th.mean(th.exp(disc_vars_g - 1), [0, 2]).sum()
 
             filters = graph_sampler.get_proba()
-
             struc_loss = lambda1*drawn_graph.sum()
 
             if linear :
@@ -342,7 +338,7 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu",
 
 
             elif functionalComplexity=="l2_norm":
-                l2_reg = th.tensor(0.).to(device)
+                l2_reg = th.Tensor([0.]).to(device)
                 for param in sam.parameters():
                     l2_reg += th.norm(param)
 
@@ -352,14 +348,6 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu",
 
 
             # Optional: prune edges and sam parameters before dag search
-            if epoch == int(train*dagstart) and use_filter:
-                ones_tensor = th.ones(len(list_nodes),len(list_nodes))
-                zeros_tensor = th.zeros(len(list_nodes),len(list_nodes))
-                if not linear:
-                    skeleton = th.where(filters.cpu() > filter_threshold, ones_tensor, zeros_tensor)
-                    sam.apply_filter(skeleton, (batch_size, cols), device)
-                graph_sampler.set_skeleton(skeleton.to(device))
-                g_optimizer = th.optim.Adam(list(sam.parameters()), lr=lr_gen)
 
             if dagloss and epoch > train * dagstart:
                 dag_constraint = notears_constr(filters*filters)
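With the use_filter pruning block gone, acyclicity is enforced solely through the notears_constr penalty once training passes the dagstart fraction. For reference, a sketch of the standard NO-TEARS acyclicity measure (Zheng et al., 2018), h(W) = trace(exp(W ∘ W)) − d, which is zero exactly when W is the adjacency matrix of a DAG; since the call site passes filters*filters, the helper below assumes it receives the elementwise square already:

# Hypothetical sketch of a notears_constr-like helper; the real cdt
# implementation lives in cdt/utils/loss.py and may differ in details.
import torch as th

def notears_constr_sketch(squared_adj):
    # squared_adj = W * W (elementwise), as passed at the call site above.
    d = squared_adj.shape[0]
    return th.trace(th.matrix_exp(squared_adj)) - d  # 0 iff W encodes a DAG
    # th.matrix_exp also requires PyTorch >= 1.7, consistent with the pin below.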
@@ -397,7 +385,7 @@ class SAM(GraphModel):
     independencies. the first version of SAM without DAG constraint is available
     as ``SAMv1``.
 
-    **Data Type:** Continuous, Mixed (Experimental)
+    **Data Type:** Continuous, (Mixed - Experimental)
 
     **Assumptions:** The class of generative models is not restricted with a
     hard contraint, but with soft constraints parametrized with the ``lambda1``
@@ -409,33 +397,39 @@ class SAM(GraphModel):
     Args:
         lr (float): Learning rate of the generators
         dlr (float): Learning rate of the discriminator
+        mixed_data (bool): Experimental -- Enable for mixed-type datasets
         lambda1 (float): L0 penalization coefficient on the causal filters
-        lambda2 (float): L0 penalization coefficient on the hidden units of the
+        lambda2 (float): L2 penalization coefficient on the weights of the
             neural network
         nh (int): Number of hidden units in the generators' hidden layers
             (regularized with lambda2)
-        dnh (int): Number of hidden units in the discriminator's hidden layer
+        dnh (int): Number of hidden units in the discriminator's hidden layers
         train_epochs (int): Number of training epochs
         test_epochs (int): Number of test epochs (saving and averaging
            the causal filters)
-        batch_size (int): Size of the batches to be fed to the SAM model.
-            Defaults to full-batch.
+        batch_size (int): Size of the batches to be fed to the SAM model
+            Defaults to full-batch
         losstype (str): type of the loss to be used (either 'fgan' (default),
-            'gan' or 'mse').
-        hlayers (int): Defines the number of hidden layers in the discriminator.
-        dagloss (bool): Activate the DAG with No-TEARS constraint.
+            'gan' or 'mse')
+        dagloss (bool): Activate the DAG with No-TEARS constraint
         dagstart (float): Controls when the DAG constraint is to be introduced
             in the training (float ranging from 0 to 1, 0 denotes the start of
-            the training and 1 the end).
-        dagpenalisation (float): Initial value of the DAG constraint.
+            the training and 1 the end)
+        dagpenalisation (float): Initial value of the DAG constraint
         dagpenalisation_increase (float): Increase incrementally at each epoch
-            the coefficient of the constraint.
-        linear (bool): If true, all generators are set to be linear generators.
-        nruns (int): Number of runs to be made for causal estimation.
-            Recommended: >=32 for optimal performance.
-        njobs (int): Numbers of jobs to be run in Parallel.
-            Recommended: 1 if no GPU available, 2*number of GPUs else.
-        gpus (int): Number of available GPUs for the algorithm.
+            the coefficient of the constraint
+        functional_complexity (str): Type of functional complexity penalization
+            (choose between 'l2_norm' and 'n_hidden_units')
+        hlayers (int): Defines the number of hidden layers in the generators
+        dhlayers (int): Defines the number of hidden layers in the discriminator
+        sampling_type (str): Type of sampling used in the structural gates of the
+            model (choose between 'sigmoid', 'sigmoid_proba' and 'gumble_proba')
+        linear (bool): If true, all generators are set to be linear generators
+        nruns (int): Number of runs to be made for causal estimation
+            Recommended: >=32 for optimal performance
+        njobs (int): Numbers of jobs to be run in Parallel
+            Recommended: 1 if no GPU available, 2*number of GPUs else
+        gpus (int): Number of available GPUs for the algorithm
         verbose (bool): verbose mode
 
     .. note::
@@ -465,14 +459,13 @@ class SAM(GraphModel):
     def __init__(self, lr=0.01, dlr=0.001, mixed_data=False,
                  lambda1=10, lambda2=0.001,
                  nh=20, dnh=200,
-                 train_epochs=3000, test_epochs=1000, batchsize=-1,
-                 losstype="fgan", dagstart=0.5, dagloss=True,
+                 train_epochs=3000, test_epochs=1000, batch_size=-1,
+                 losstype="fgan", dagloss=True, dagstart=0.5,
                  dagpenalization=0,
-                 dagpenalization_increase=0.01, use_filter=False,
-                 filter_threshold=.5,
+                 dagpenalization_increase=0.01,
                  functional_complexity='l2_norm', hlayers=2, dhlayers=2,
-                 sampling_type='sigmoidproba', linear=False,
-                 njobs=None, gpus=None, verbose=None, nruns=8):
+                 sampling_type='sigmoidproba', linear=False, nruns=8,
+                 njobs=None, gpus=None, verbose=None):
 
         """Init and parametrize the SAM model."""
         super(SAM, self).__init__()
@@ -485,19 +478,17 @@ def __init__(self, lr=0.01, dlr=0.001, mixed_data=False,
         self.dnh = dnh
         self.train = train_epochs
         self.test = test_epochs
-        self.batchsize = batchsize
+        self.batch_size = batch_size
         self.dagstart = dagstart
         self.dagloss = dagloss
         self.dagpenalization = dagpenalization
         self.dagpenalization_increase = dagpenalization_increase
-        self.use_filter = use_filter
-        self.filter_threshold = filter_threshold
         self.losstype = losstype
         self.functionalComplexity = functional_complexity
         self.sampletype = sampling_type
         self.linear = linear
-        self.numberHiddenLayersD = hlayers
-        self.numberHiddenLayersG = dhlayers
+        self.numberHiddenLayersG = hlayers
+        self.numberHiddenLayersD = dhlayers
         self.njobs = SETTINGS.get_default(njobs=njobs)
         self.gpus = SETTINGS.get_default(gpu=gpus)
         self.verbose = SETTINGS.get_default(verbose=verbose)
@@ -529,13 +520,11 @@ def predict(self, data, graph=None,
                           lambda1=self.lambda1, lambda2=self.lambda2,
                           nh=self.nh, dnh=self.dnh,
                           train=self.train,
-                          test=self.test, batch_size=self.batchsize,
+                          test=self.test, batch_size=self.batch_size,
                           dagstart=self.dagstart,
                           dagloss=self.dagloss,
                           dagpenalization=self.dagpenalization,
                           dagpenalization_increase=self.dagpenalization_increase,
-                          use_filter=self.use_filter,
-                          filter_threshold=self.filter_threshold,
                           losstype=self.losstype,
                           functionalComplexity=self.functionalComplexity,
                           sampletype=self.sampletype,
@@ -552,13 +541,11 @@ def predict(self, data, graph=None,
                           lambda1=self.lambda1, lambda2=self.lambda2,
                           nh=self.nh, dnh=self.dnh,
                           train=self.train,
-                          test=self.test, batch_size=self.batchsize,
+                          test=self.test, batch_size=self.batch_size,
                           dagstart=self.dagstart,
                           dagloss=self.dagloss,
                           dagpenalization=self.dagpenalization,
                           dagpenalization_increase=self.dagpenalization_increase,
-                          use_filter=self.use_filter,
-                          filter_threshold=self.filter_threshold,
                           losstype=self.losstype,
                           functionalComplexity=self.functionalComplexity,
                           sampletype=self.sampletype,
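Net effect of the class-level changes: batchsize is renamed to batch_size end to end, the experimental use_filter/filter_threshold options disappear from __init__ and both predict branches, and the swapped assignments are fixed so that hlayers now really sets the generators' depth and dhlayers the discriminator's. A minimal usage sketch against the updated signature; the bundled sachs dataset and the exact hyperparameters are illustrative, not prescribed by the commit:

# Hypothetical usage of the post-commit SAM API (parameter names follow
# the new __init__ signature in the diff above).
import networkx as nx
from cdt.data import load_dataset
from cdt.causality.graph import SAM

data, _ = load_dataset('sachs')           # continuous observational data

model = SAM(lr=0.01, dlr=0.001, lambda1=10, lambda2=0.001,
            nh=20, dnh=200, train_epochs=3000, test_epochs=1000,
            batch_size=-1,                # renamed from `batchsize`; -1 = full batch
            dagloss=True, dagstart=0.5,
            nruns=8)                      # docstring recommends >=32 for best results

graph = model.predict(data)               # returns a networkx.DiGraph
print(nx.adjacency_matrix(graph).todense())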

cdt/utils/torch.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@
 from torch.nn import Parameter
 from torch.nn.modules.batchnorm import _BatchNorm
 import torch.distributions.relaxed_bernoulli as relaxed_bernoulli
-from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transformed_distribution import TransformedDistribution
 from torch.distributions.transforms import SigmoidTransform,AffineTransform
 from torch.distributions.uniform import Uniform

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ scikit-learn>=0.21.1
 joblib>=0.13.2
 pandas>=0.24.1
 networkx>=2.3
-torch
+torch>1.7.0
 tqdm>4.0.0
 GPUtil>=1.4.0
 statsmodels>=0.9.0
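The loose torch pin is tightened to torch>1.7.0, matching the new th.count_nonzero call in SAM.py. A quick runtime guard, sketched with the packaging helper (an assumption; any version-comparison approach would do):

# Sketch: fail fast if the installed torch predates the APIs this commit uses.
import torch
from packaging import version

installed = version.parse(torch.__version__.split('+')[0])
assert installed > version.parse("1.7.0"), "cdt now requires torch > 1.7.0"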

tests/scripts/test_causality_graph.py

Lines changed: 4 additions & 4 deletions

@@ -61,7 +61,7 @@ def test_SAMv1():
 
 
 if __name__ == "__main__":
-    # test_SAM()
-    test_directed()
-    test_undirected()
-    test_graph()
+    test_SAM()
+    # test_directed()
+    # test_undirected()
+    # test_graph()
