|
3 | 3 | import networkx as nx |
4 | 4 | from ndlib.models.DiffusionModel import DiffusionModel |
5 | 5 | import six |
| 6 | +import warnings |
| 7 | +import numpy as np |
| 8 | +import future.utils |
6 | 9 |
|
7 | 10 | __author__ = "Giulio Rossetti" |
8 | 11 | __license__ = "BSD-2-Clause" |
@@ -85,4 +88,116 @@ def execute_iterations(self, node_status=True): |
85 | 88 | system_status.append(its) |
86 | 89 | return system_status |
87 | 90 |
|
| 91 | + def set_initial_status(self, configuration): |
| 92 | + """ |
| 93 | + Set the initial model configuration |
| 94 | +
|
| 95 | + :param configuration: a ```ndlib.models.ModelConfig.Configuration``` object |
| 96 | + """ |
| 97 | + |
| 98 | + self.__validate_configuration(configuration) |
| 99 | + |
| 100 | + nodes_cfg = configuration.get_nodes_configuration() |
| 101 | + # Set additional node information |
| 102 | + |
| 103 | + for param, node_to_value in future.utils.iteritems(nodes_cfg): |
| 104 | + if len(node_to_value) < len(self.graph.nodes()): |
| 105 | + raise ConfigurationException({"message": "Not all nodes have a configuration specified"}) |
| 106 | + |
| 107 | + self.params['nodes'][param] = node_to_value |
| 108 | + |
| 109 | + edges_cfg = configuration.get_edges_configuration() |
| 110 | + # Set additional edges information |
| 111 | + for param, edge_to_values in future.utils.iteritems(edges_cfg): |
| 112 | + if len(edge_to_values) == len(self.graph.edges()): |
| 113 | + self.params['edges'][param] = {} |
| 114 | + for e in edge_to_values: |
| 115 | + self.params['edges'][param][e] = edge_to_values[e] |
| 116 | + |
| 117 | + # Set initial status |
| 118 | + model_status = configuration.get_model_configuration() |
| 119 | + |
| 120 | + for param, nodes in future.utils.iteritems(model_status): |
| 121 | + self.params['status'][param] = nodes |
| 122 | + for node in nodes: |
| 123 | + self.status[node] = self.available_statuses[param] |
| 124 | + |
| 125 | + # Set model additional information |
| 126 | + model_params = configuration.get_model_parameters() |
| 127 | + for param, val in future.utils.iteritems(model_params): |
| 128 | + self.params['model'][param] = val |
| 129 | + |
| 130 | + # Handle initial infection |
| 131 | + if 'Infected' not in self.params['status']: |
| 132 | + if 'percentage_infected' in self.params['model']: |
| 133 | + number_of_initial_infected = len(self.graph.nodes()) * float(self.params['model']['percentage_infected']) |
| 134 | + if number_of_initial_infected < 1: |
| 135 | + warnings.warn('Graph with less than 100 nodes: a single node will be set as infected') |
| 136 | + number_of_initial_infected = 1 |
| 137 | + |
| 138 | + available_nodes = [n for n in self.status if self.status[n] == 0] |
| 139 | + sampled_nodes = np.random.choice(available_nodes, int(number_of_initial_infected), replace=False) |
| 140 | + for k in sampled_nodes: |
| 141 | + self.status[k] = self.available_statuses['Infected'] |
| 142 | + |
| 143 | + self.initial_status = self.status |
| 144 | + |
| 145 | + def __validate_configuration(self, configuration): |
| 146 | + """ |
| 147 | + Validate the consistency of a Configuration object for the specific model |
| 148 | +
|
| 149 | + :param configuration: a Configuration object instance |
| 150 | + """ |
| 151 | + if "Infected" not in self.available_statuses: |
| 152 | + raise ConfigurationException("'Infected' status not defined.") |
| 153 | + |
| 154 | + # Checking mandatory parameters |
| 155 | + omp = set([k for k in self.parameters['model'].keys() if not self.parameters['model'][k]['optional']]) |
| 156 | + onp = set([k for k in self.parameters['nodes'].keys() if not self.parameters['nodes'][k]['optional']]) |
| 157 | + oep = set([k for k in self.parameters['edges'].keys() if not self.parameters['edges'][k]['optional']]) |
| 158 | + |
| 159 | + mdp = set(configuration.get_model_parameters().keys()) |
| 160 | + ndp = set(configuration.get_nodes_configuration().keys()) |
| 161 | + edp = set(configuration.get_edges_configuration().keys()) |
| 162 | + |
| 163 | + if len(omp) > 0: |
| 164 | + if len(omp & mdp) != len(omp): |
| 165 | + raise ConfigurationException({"message": "Missing mandatory model parameter(s)", "parameters": omp-mdp}) |
| 166 | + |
| 167 | + if len(onp) > 0: |
| 168 | + if len(onp & ndp) != len(onp): |
| 169 | + raise ConfigurationException({"message": "Missing mandatory node parameter(s)", "parameters": onp-ndp}) |
| 170 | + |
| 171 | + if len(oep) > 0: |
| 172 | + if len(oep & edp) != len(oep): |
| 173 | + raise ConfigurationException({"message": "Missing mandatory edge parameter(s)", "parameters": oep-edp}) |
| 174 | + |
| 175 | + # Checking optional parameters |
| 176 | + omp = set([k for k in self.parameters['model'].keys() if self.parameters['model'][k]['optional']]) |
| 177 | + onp = set([k for k in self.parameters['nodes'].keys() if self.parameters['nodes'][k]['optional']]) |
| 178 | + oep = set([k for k in self.parameters['edges'].keys() if self.parameters['edges'][k]['optional']]) |
| 179 | + |
| 180 | + if len(omp) > 0: |
| 181 | + for param in omp: |
| 182 | + if param not in mdp: |
| 183 | + configuration.add_model_parameter(param, self.parameters['model'][param]['default']) |
| 184 | + |
| 185 | + if len(onp) > 0: |
| 186 | + for param in onp: |
| 187 | + if param not in ndp: |
| 188 | + for nid in self.graph.nodes(): |
| 189 | + configuration.add_node_configuration(param, nid, self.parameters['nodes'][param]['default']) |
| 190 | + |
| 191 | + if len(oep) > 0: |
| 192 | + for param in oep: |
| 193 | + if param not in edp: |
| 194 | + for eid in self.graph.edges(): |
| 195 | + configuration.add_edge_configuration(param, eid, self.parameters['edges'][param]['default']) |
| 196 | + |
| 197 | + # Checking initial simulation status |
| 198 | + sts = set(configuration.get_model_configuration().keys()) |
| 199 | + if self.discrete_state and "Infected" not in sts and "percentage_infected" not in mdp: |
| 200 | + warnings.warn('Initial infection missing: a random sample of 5% of graph nodes will be set as infected') |
| 201 | + self.params['model']["percentage_infected"] = 0.05 |
| 202 | + |
88 | 203 | iteration_bunch = execute_snapshots |
0 commit comments