|
15 | 15 | from lib.cuckoo.common.constants import CUCKOO_ROOT |
16 | 16 | from math import log |
17 | 17 |
|
| 18 | +if Config("cuckooml").cuckooml.plotting: |
| 19 | + try: |
| 20 | + import matplotlib.pyplot as plt |
| 21 | + import seaborn as sns |
| 22 | + except ImportError, e: |
| 23 | + print >> sys.stderr, "Some error while importing" |
| 24 | + print >> sys.stderr, e |
| 25 | + |
| 26 | + |
18 | 27 | try: |
19 | | - import matplotlib.pyplot as plt |
20 | 28 | import numpy as np |
21 | 29 | import pandas as pd |
22 | | - import seaborn as sns |
23 | 30 | from hdbscan import HDBSCAN |
24 | 31 | from sklearn import metrics |
25 | 32 | from sklearn.cluster import DBSCAN |
@@ -797,6 +804,14 @@ def filter_dataset(self, dataset=None, feature_coverage=0.1, |
797 | 804 |
|
798 | 805 | def detect_abnormal_behaviour(self, count_dataset=None, figures=True): |
799 | 806 | """Detect samples that behave significantly different than others.""" |
| 807 | + |
| 808 | + # Safety check for plotting |
| 809 | + if not Config("cuckooml").cuckooml.plotting and figures: |
| 810 | + print >> sys.stderr, "Warning: 'plotting' and 'figures' do not match. \ |
| 811 | + Plotting modules might not be imported." |
| 812 | + figures = False |
| 813 | + |
| 814 | + |
800 | 815 | if count_dataset is None: |
801 | 816 | # Pull all count features |
802 | 817 | count_features = self.feature_category(":count:") |
@@ -1133,6 +1148,14 @@ def performance_metric(clustering, labels, data, noise): |
1133 | 1148 |
|
1134 | 1149 | def clustering_label_distribution(self, clustering, labels, plot=False): |
1135 | 1150 | """Get statistics about number of ground truth labels per cluster.""" |
| 1151 | + |
| 1152 | + # Safety check for plotting |
| 1153 | + if not Config("cuckooml").cuckooml.plotting and plot: |
| 1154 | + print >> sys.stderr, "Warning: 'plotting' and 'plot' do not match.\ |
| 1155 | + Plotting modules might not be imported." |
| 1156 | + plot = False |
| 1157 | + |
| 1158 | + |
1136 | 1159 | cluster_ids = set(clustering["label"].tolist()) |
1137 | 1160 | labels_ids = set(labels["label"].tolist()) |
1138 | 1161 | cluster_distribution = {} |
|
0 commit comments