Skip to content

Commit 2fa4ad9

Browse files
authored
Merge pull request #96 from iraikov/feature/register_population
Env: added support for registering new populations
2 parents d7c9ba1 + decedb2 commit 2fa4ad9

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "miv-simulator"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
description = "Mind-In-Vitro simulator"
55
authors = []
66
dependencies = [

src/miv_simulator/env.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,26 @@ def load_celltypes(self) -> None:
889889
weights_dict["closure"] = clos
890890
synapses_dict["weights"] = weights_dicts
891891

892+
def register_population(self, population_name, population_cell_distribution):
893+
if population_name in self.Populations:
894+
return None
895+
896+
max_pop_enum = 0
897+
pop_offset = 0
898+
for this_pop_name, this_pop_enum in self.Populations.items():
899+
max_pop_enum = max(this_pop_enum, max_pop_enum)
900+
pop_offset += self.celltypes[this_pop_name]["num"]
901+
902+
pop_id = max_pop_enum + 1
903+
self.Populations[population_name] = pop_id
904+
cell_distribution = {}
905+
if "Cell Distribution" in self.geometry:
906+
cell_distribution = self.geometry["Cell Distribution"]
907+
else:
908+
self.geometry["Cell Distribution"] = population_cell_distribution
909+
cell_distribution[population_name] = population_cell_distribution
910+
return {"population_id": pop_id, "population_start_gid": pop_offset}
911+
892912
def clear(self):
893913
self.gidset = set()
894914
self.gjlist = []

src/miv_simulator/input_spike_trains.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ def generate_input_spike_trains(
9393
logger.info(f"{comm.size} ranks have been allocated")
9494

9595
population_name = population.name
96+
start_gid = 0
97+
if hasattr(population, "start_gid"):
98+
start_gid = population.start_gid
9699

97100
soma_positions_dict = None
98101
if coords_path is not None:
@@ -186,6 +189,8 @@ def generate_input_spike_trains(
186189
feature_items = list(population.features.items())
187190
n_iter = comm.allreduce(len(feature_items), op=MPI.MAX)
188191

192+
logger.info(f"n_iter = {n_iter} feature_items = {feature_items}")
193+
189194
if not dry_run and rank == 0:
190195
if output_path is None:
191196
raise RuntimeError("generate_input_spike_trains: missing output_path")
@@ -198,6 +203,7 @@ def generate_input_spike_trains(
198203
for iter_count in range(n_iter):
199204
if iter_count < len(feature_items):
200205
gid, input_feature = feature_items[iter_count]
206+
gid += start_gid
201207
else:
202208
gid, input_feature = None, None
203209
if gid is not None:
@@ -218,13 +224,11 @@ def generate_input_spike_trains(
218224

219225
# Get spike response
220226
response = input_feature.get_response(processed_signal)
227+
if isinstance(response, list):
228+
response = np.concatenate(np.concatenate(response, dtype=np.float32))
221229

222230
if len(response) > 0:
223-
spikes_attr_dict[gid] = {
224-
output_spike_train_attr_name: np.concatenate(
225-
response, dtype=np.float32
226-
)
227-
}
231+
spikes_attr_dict[gid] = {output_spike_train_attr_name: response}
228232

229233
gid_count += 1
230234
if (iter_count > 0 and iter_count % write_every == 0) or (

0 commit comments

Comments
 (0)