Skip to content

Commit e408d9b

Browse files
author
Alexander Ororbia
committed
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn into major_release_update
2 parents f4d47d4 + b51cfd0 commit e408d9b

File tree

2 files changed

+214
-0
lines changed

2 files changed

+214
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# %%
2+
3+
# Adapted from meta jepa
4+
5+
import math
6+
import numpy as np
7+
from multiprocessing import Value
8+
9+
class MaskCollator(object):
10+
11+
def __init__(
12+
self,
13+
cfgs_mask,
14+
crop_size=(224, 224),
15+
patch_size=(16, 16),
16+
):
17+
super(MaskCollator, self).__init__()
18+
19+
self.mask_generators = []
20+
for m in cfgs_mask:
21+
mask_generator = _MaskGenerator(
22+
crop_size=crop_size,
23+
patch_size=patch_size,
24+
pred_mask_scale=m.get('spatial_scale'),
25+
aspect_ratio=m.get('aspect_ratio'),
26+
npred=m.get('num_blocks'),
27+
max_keep=m.get('max_keep', None),
28+
)
29+
self.mask_generators.append(mask_generator)
30+
31+
def step(self):
32+
for mask_generator in self.mask_generators:
33+
mask_generator.step()
34+
35+
def __call__(self, batch):
36+
37+
batch_size = len(batch)
38+
39+
collated_masks_pred, collated_masks_enc = [], []
40+
for i, mask_generator in enumerate(self.mask_generators):
41+
masks_enc, masks_pred = mask_generator(batch_size)
42+
collated_masks_enc.append(masks_enc)
43+
collated_masks_pred.append(masks_pred)
44+
45+
return collated_masks_enc, collated_masks_pred
46+
47+
48+
class _MaskGenerator(object):
49+
50+
def __init__(
51+
self,
52+
crop_size=(224, 224),
53+
patch_size=(16, 16),
54+
pred_mask_scale=(0.2, 0.8),
55+
aspect_ratio=(0.3, 3.0),
56+
npred=1,
57+
max_keep=None,
58+
):
59+
super(_MaskGenerator, self).__init__()
60+
if not isinstance(crop_size, tuple):
61+
crop_size = (crop_size, ) * 2
62+
self.crop_size = crop_size
63+
self.height, self.width = crop_size[0] // patch_size[0], crop_size[1] // patch_size[1]
64+
65+
self.patch_size = patch_size
66+
self.aspect_ratio = aspect_ratio
67+
self.pred_mask_scale = pred_mask_scale
68+
self.npred = npred
69+
self.max_keep = max_keep
70+
self._itr_counter = Value('i', -1) # collator is shared across worker processes
71+
72+
def step(self):
73+
i = self._itr_counter
74+
with i.get_lock():
75+
i.value = (i.value + 1) % 2**16
76+
v = i.value
77+
return v
78+
79+
def _sample_block_size(
80+
self,
81+
rng: np.random.RandomState,
82+
scale,
83+
aspect_ratio_scale
84+
):
85+
# -- Sample spatial block mask scale
86+
_rand = rng.random()
87+
min_s, max_s = scale
88+
spatial_mask_scale = min_s + _rand * (max_s - min_s)
89+
spatial_num_keep = int(self.height * self.width * spatial_mask_scale)
90+
91+
# -- Sample block aspect-ratio
92+
_rand = rng.random()
93+
min_ar, max_ar = aspect_ratio_scale
94+
aspect_ratio = min_ar + _rand * (max_ar - min_ar)
95+
96+
# -- Compute block height and width (given scale and aspect-ratio)
97+
h = int(round(math.sqrt(spatial_num_keep * aspect_ratio)))
98+
w = int(round(math.sqrt(spatial_num_keep / aspect_ratio)))
99+
h = min(h, self.height)
100+
w = min(w, self.width)
101+
102+
return (h, w)
103+
104+
def _sample_block_mask(self, b_size, rng: np.random.RandomState):
105+
h, w = b_size
106+
top = rng.randint(0, self.height - h + 1)
107+
left = rng.randint(0, self.width - w + 1)
108+
109+
mask = np.ones((self.height, self.width), dtype=np.int32)
110+
mask[top:top+h, left:left+w] = 0
111+
112+
return mask
113+
114+
def __call__(self, batch_size):
115+
"""
116+
Create encoder and predictor masks when collating imgs into a batch
117+
# 1. sample pred block size using seed
118+
# 2. sample several pred block locations for each image (w/o seed)
119+
# 3. return pred masks and complement (enc mask)
120+
"""
121+
seed = self.step()
122+
rng = np.random.RandomState(seed)
123+
p_size = self._sample_block_size(
124+
rng=rng,
125+
scale=self.pred_mask_scale,
126+
aspect_ratio_scale=self.aspect_ratio,
127+
)
128+
129+
collated_masks_pred, collated_masks_enc = [], []
130+
min_keep_enc = min_keep_pred = self.height * self.width
131+
for _ in range(batch_size):
132+
133+
empty_context = True
134+
while empty_context:
135+
# Create a mask for this sample
136+
mask_e = np.ones((self.height, self.width), dtype=np.int32)
137+
for _ in range(self.npred):
138+
mask_e *= self._sample_block_mask(p_size, rng)
139+
mask_e = mask_e.flatten()
140+
141+
mask_p = np.where(mask_e == 0)[0]
142+
mask_e = np.where(mask_e != 0)[0]
143+
144+
empty_context = len(mask_e) == 0
145+
if not empty_context:
146+
min_keep_pred = min(min_keep_pred, len(mask_p))
147+
min_keep_enc = min(min_keep_enc, len(mask_e))
148+
collated_masks_pred.append(mask_p)
149+
collated_masks_enc.append(mask_e)
150+
151+
if self.max_keep is not None:
152+
min_keep_enc = min(min_keep_enc, self.max_keep)
153+
154+
# Truncate arrays to the minimum length to create uniform arrays
155+
collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred]
156+
collated_masks_pred = np.array(collated_masks_pred)
157+
158+
collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc]
159+
collated_masks_enc = np.array(collated_masks_enc)
160+
161+
return collated_masks_enc, collated_masks_pred
162+
163+
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import RateCell
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
8+
def test_rateCell1():
9+
## create seeding keys
10+
dkey = random.PRNGKey(1234)
11+
dkey, *subkeys = random.split(dkey, 6)
12+
# in_dim = 9 # ... dimension of patch data ...
13+
# hid_dim = 9 # ... number of atoms in the dictionary matrix
14+
dt = 1. # ms
15+
T = 300 # ms # (OR) number of E-steps to take during inference
16+
# ---- build a sparse coding linear generative model with a Cauchy prior ----
17+
with Context("Circuit") as circuit:
18+
a = RateCell(name="a", n_units=1, tau_m=0.,
19+
act_fx="identity", key=subkeys[0])
20+
b = RateCell(name="b", n_units=1, tau_m=0.,
21+
act_fx="identity", key=subkeys[1])
22+
23+
# wire output compartment (rate-coded output zF) of RateCell `a` to input compartment of HebbianSynapse `Wab`
24+
25+
# wire output compartment of HebbianSynapse `Wab` to input compartment (electrical current j) RateCell `b`
26+
b.j << a.zF
27+
28+
## create and compile core simulation commands
29+
reset_cmd, reset_args = circuit.compile_by_key(a, b, compile_key="reset")
30+
circuit.add_command(wrap_command(jit(circuit.reset)), name="reset")
31+
32+
advance_cmd, advance_args = circuit.compile_by_key(a, b,
33+
compile_key="advance_state")
34+
circuit.add_command(wrap_command(jit(circuit.advance_state)), name="advance")
35+
36+
37+
## set up non-compiled utility commands
38+
@Context.dynamicCommand
39+
def clamp(x):
40+
a.j.set(x)
41+
42+
x_seq = jnp.asarray([[1, 1, 0, 0, 1]], dtype=jnp.float32)
43+
44+
circuit.reset()
45+
for ts in range(x_seq.shape[1]):
46+
x_t = jnp.expand_dims(x_seq[0,ts], axis=0) ## get data at time t
47+
circuit.clamp(x_t)
48+
circuit.advance(t=ts*1., dt=1.)
49+
50+
print(a.zF.value)
51+
# assertion here if needed!

0 commit comments

Comments
 (0)