Skip to content

Commit db019e3

Browse files
authored
Introduce trainer module and refactor public APIs and docs (#68)
* Add new regularization classes and corresponding tests for ElasticNet, Huber, Group Lasso, Total Variation, Max Norm, Entropy, Orthogonal, Spectral Norm, and various prior distributions * Add trainer module with training framework and progress bar utilities * Update test descriptions for fit_hyper parameter in regularization tests * delete braintools.param module * refactor: reorganize imports and enhance module documentation for clarity and usability * bump: update version to 0.1.7 * refactor: update import statement for matfile module
1 parent 206c5ec commit db019e3

36 files changed

+10998
-4670
lines changed

braintools/__init__.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515

1616

17-
__version__ = "0.1.6"
17+
__version__ = "0.1.7"
1818
__version_info__ = tuple(map(int, __version__.split(".")))
1919

2020
from . import conn
@@ -23,28 +23,62 @@
2323
from . import input
2424
from . import metric
2525
from . import optim
26-
from . import param
2726
from . import quad
2827
from . import surrogate
28+
from . import trainer
2929
from . import tree
3030
from . import visualize
31-
from ._spike_encoder import *
32-
from ._spike_encoder import __all__ as encoder_all
33-
from ._spike_operation import *
34-
from ._spike_operation import __all__ as operation_all
31+
from ._spike_encoder import (
32+
LatencyEncoder,
33+
RateEncoder,
34+
PoissonEncoder,
35+
PopulationEncoder,
36+
BernoulliEncoder,
37+
DeltaEncoder,
38+
StepCurrentEncoder,
39+
SpikeCountEncoder,
40+
TemporalEncoder,
41+
RankOrderEncoder,
42+
)
43+
from ._spike_operation import (
44+
spike_bitwise_or,
45+
spike_bitwise_and,
46+
spike_bitwise_iand,
47+
spike_bitwise_not,
48+
spike_bitwise_xor,
49+
spike_bitwise_ixor,
50+
spike_bitwise,
51+
)
3552

3653
__all__ = [
37-
'conn',
38-
'input',
39-
'init',
40-
'file',
41-
'metric',
42-
'visualize',
43-
'optim',
44-
'tree',
45-
'quad',
46-
'surrogate',
47-
'param',
48-
] + encoder_all + operation_all
54+
'conn',
55+
'input',
56+
'init',
57+
'file',
58+
'metric',
59+
'visualize',
60+
'optim',
61+
'trainer',
62+
'tree',
63+
'quad',
64+
'surrogate',
4965

50-
del encoder_all, operation_all
66+
'LatencyEncoder',
67+
'RateEncoder',
68+
'PoissonEncoder',
69+
'PopulationEncoder',
70+
'BernoulliEncoder',
71+
'DeltaEncoder',
72+
'StepCurrentEncoder',
73+
'SpikeCountEncoder',
74+
'TemporalEncoder',
75+
'RankOrderEncoder',
76+
77+
'spike_bitwise_or',
78+
'spike_bitwise_and',
79+
'spike_bitwise_iand',
80+
'spike_bitwise_not',
81+
'spike_bitwise_xor',
82+
'spike_bitwise_ixor',
83+
'spike_bitwise',
84+
]

braintools/conn/__init__.py

Lines changed: 184 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -98,23 +98,187 @@
9898
9999
"""
100100

101-
from ._base import *
102-
from ._base import __all__ as base_all
103-
from ._biological import *
104-
from ._biological import __all__ as biological_all
105-
from ._compartment import *
106-
from ._compartment import __all__ as comp_all
107-
from ._kernel import *
108-
from ._kernel import __all__ as kernel_all
109-
from ._random import *
110-
from ._random import __all__ as point_all
111-
from ._regular import *
112-
from ._regular import __all__ as regular_all
113-
from ._spatial import *
114-
from ._spatial import __all__ as spatial_all
115-
from ._topological import *
116-
from ._topological import __all__ as topological_all
117-
118-
__all__ = base_all + comp_all + kernel_all + point_all + spatial_all + topological_all + biological_all + regular_all
119-
del base_all, comp_all, kernel_all, point_all, spatial_all, topological_all, biological_all
120-
del regular_all
101+
# Base classes and utilities
102+
from ._base import (
103+
ConnectionResult,
104+
Connectivity,
105+
PointConnectivity,
106+
MultiCompartmentConnectivity,
107+
ScaledConnectivity,
108+
CompositeConnectivity,
109+
)
110+
111+
# Biological patterns
112+
from ._biological import (
113+
ExcitatoryInhibitory,
114+
)
115+
116+
# Multi-compartment connectivity
117+
from ._compartment import (
118+
# Compartment type constants
119+
SOMA,
120+
BASAL_DENDRITE,
121+
APICAL_DENDRITE,
122+
AXON,
123+
124+
# Basic compartment patterns
125+
CompartmentSpecific,
126+
AllToAllCompartments,
127+
128+
# Anatomical targeting patterns
129+
SomaToDendrite,
130+
AxonToSoma,
131+
DendriteToSoma,
132+
AxonToDendrite,
133+
DendriteToDendrite,
134+
135+
# Morphology-aware patterns
136+
ProximalTargeting,
137+
DistalTargeting,
138+
BranchSpecific,
139+
MorphologyDistance,
140+
141+
# Dendritic patterns
142+
DendriticTree,
143+
BasalDendriteTargeting,
144+
ApicalDendriteTargeting,
145+
DendriticIntegration,
146+
147+
# Axonal patterns
148+
AxonalProjection,
149+
AxonalBranching,
150+
AxonalArborization,
151+
TopographicProjection,
152+
153+
# Synaptic patterns
154+
SynapticPlacement,
155+
SynapticClustering,
156+
157+
# Custom patterns
158+
CustomCompartment,
159+
)
160+
161+
# Kernel-based connectivity
162+
from ._kernel import (
163+
Conv2dKernel,
164+
GaussianKernel,
165+
GaborKernel,
166+
DoGKernel,
167+
MexicanHat,
168+
SobelKernel,
169+
LaplacianKernel,
170+
CustomKernel,
171+
)
172+
173+
# Random connectivity patterns
174+
from ._random import (
175+
Random,
176+
FixedProb,
177+
ClusteredRandom,
178+
)
179+
180+
# Regular patterns
181+
from ._regular import (
182+
AllToAll,
183+
OneToOne,
184+
)
185+
186+
# Spatial connectivity patterns
187+
from ._spatial import (
188+
DistanceDependent,
189+
Gaussian,
190+
Exponential,
191+
Ring,
192+
Grid2d,
193+
RadialPatches,
194+
)
195+
196+
# Topological patterns
197+
from ._topological import (
198+
SmallWorld,
199+
ScaleFree,
200+
Regular,
201+
ModularRandom,
202+
ModularGeneral,
203+
HierarchicalRandom,
204+
CorePeripheryRandom,
205+
)
206+
207+
__all__ = [
208+
# Base classes
209+
'ConnectionResult',
210+
'Connectivity',
211+
'PointConnectivity',
212+
'MultiCompartmentConnectivity',
213+
'ScaledConnectivity',
214+
'CompositeConnectivity',
215+
216+
# Biological patterns
217+
'ExcitatoryInhibitory',
218+
219+
# Compartment constants
220+
'SOMA',
221+
'BASAL_DENDRITE',
222+
'APICAL_DENDRITE',
223+
'AXON',
224+
225+
# Multi-compartment patterns
226+
'CompartmentSpecific',
227+
'AllToAllCompartments',
228+
'SomaToDendrite',
229+
'AxonToSoma',
230+
'DendriteToSoma',
231+
'AxonToDendrite',
232+
'DendriteToDendrite',
233+
'ProximalTargeting',
234+
'DistalTargeting',
235+
'BranchSpecific',
236+
'MorphologyDistance',
237+
'DendriticTree',
238+
'BasalDendriteTargeting',
239+
'ApicalDendriteTargeting',
240+
'DendriticIntegration',
241+
'AxonalProjection',
242+
'AxonalBranching',
243+
'AxonalArborization',
244+
'TopographicProjection',
245+
'SynapticPlacement',
246+
'SynapticClustering',
247+
'CustomCompartment',
248+
249+
# Kernel patterns
250+
'Conv2dKernel',
251+
'GaussianKernel',
252+
'GaborKernel',
253+
'DoGKernel',
254+
'MexicanHat',
255+
'SobelKernel',
256+
'LaplacianKernel',
257+
'CustomKernel',
258+
259+
# Random patterns
260+
'Random',
261+
'FixedProb',
262+
'ClusteredRandom',
263+
264+
# Regular patterns
265+
'AllToAll',
266+
'OneToOne',
267+
268+
# Spatial patterns
269+
'DistanceDependent',
270+
'Gaussian',
271+
'Exponential',
272+
'Ring',
273+
'Grid2d',
274+
'RadialPatches',
275+
276+
# Topological patterns
277+
'SmallWorld',
278+
'ScaleFree',
279+
'Regular',
280+
'ModularRandom',
281+
'ModularGeneral',
282+
'HierarchicalRandom',
283+
'CorePeripheryRandom',
284+
]

0 commit comments

Comments
 (0)