Skip to content

Commit 61d0557

Browse files
committed
Adjust code for jax==0.7.1, set new version number
1 parent ae297d9 commit 61d0557

File tree

4 files changed

+10
-4
lines changed

4 files changed

+10
-4
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
![Figure 1](docs/figure1.png)
33
This is the official implementation of our paper [Scalable Event-by-event Processing of Neuromorphic Sensory Signals With Deep State-Space Models
44
](https://arxiv.org/abs/2404.18508).
5+
## Version 0.2 released
6+
- Fixed a bug in the event tokenization for DVS. Improves our DVS Gestures result to 99.2% accuracy (without using spatial convolutions at all)
7+
- Updated DVS configs
8+
- Compatibility with recent JAX versions >= 0.7.1
9+
10+
## Introduction
511
The core motivation for this work was the irregular time-series modeling problem presented in the paper [Simplified State Space Layers for Sequence Modeling
612
](https://arxiv.org/abs/2208.04933).
713
We acknowledge the awesome [S5 project](https://github.com/lindermanlab/S5) and the trainer class provided by this [UvA tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/guide4/Research_Projects_with_JAX.html), which highly influenced our code.

event_ssm/train_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def get_first_device(x):
152152
def print_model_size(params, name=''):
153153
fn_is_complex = lambda x: x.dtype in [np.complex64, np.complex128]
154154
param_sizes = map_nested_fn(lambda k, param: param.size * (2 if fn_is_complex(param) else 1))(params)
155-
total_params_size = sum(jax.tree_leaves(param_sizes))
155+
total_params_size = sum(jax.tree.leaves(param_sizes))
156156
print('[*] Model parameter count:', total_params_size)
157157

158158

event_ssm/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(
6161
self.best_eval_metrics = {}
6262

6363
# logger details
64-
self.log_dir = os.path.join(self.log_config.log_dir)
64+
self.log_dir = os.path.abspath(os.path.join(self.log_config.log_dir))
6565
print('[*] Logging to', self.log_dir)
6666

6767
if not os.path.isdir(self.log_dir):
@@ -72,7 +72,7 @@ def __init__(
7272
os.makedirs(os.path.join(self.log_dir, 'checkpoints'))
7373

7474
num_parameters = int(sum(
75-
[arr.size for arr in jax.tree_flatten(self.train_state.params)[0]
75+
[arr.size for arr in jax.tree.flatten(self.train_state.params)[0]
7676
if isinstance(arr, Array)]
7777
) / self.world_size)
7878
print("[*] Number of model parameters:", num_parameters)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name='Event-based-SSM',
55
packages=['event_ssm'],
6-
version='0.1',
6+
version='0.2',
77
description='Event-stream modeling with state-space models',
88
author='Mark Schoene',
99
author_email='[email protected]',

0 commit comments

Comments
 (0)