Skip to content

Commit 7a08b52

Browse files
committed
* add model avg util
1 parent dd564be commit 7a08b52

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from functools import partial
2+
3+
import jax
4+
import tensorflow_probability.substrates.jax as tfp
5+
6+
from dsa2000_fm.actors.average_utils import average_rule
7+
8+
tfpd = tfp.distributions
9+
10+
11+
@partial(jax.jit, static_argnames=['Tm', 'Cm'])
12+
def average_model(vis_model, Tm: int, Cm: int):
13+
"""
14+
Averages model vis data.
15+
16+
Args:
17+
vis_model: [D, T, B, C, ...]
18+
Tm: number of model times
19+
Cm: number of model channels
20+
21+
Returns:
22+
[D, Tm, B, Cm, ...]
23+
"""
24+
# average data to match model: [Ts, B, Cs[, 2, 2]] -> [Tm, B, Cm[, 2, 2]]
25+
if Tm is not None:
26+
time_average_rule = partial(
27+
average_rule,
28+
num_model_size=Tm,
29+
axis=1
30+
)
31+
else:
32+
time_average_rule = lambda x: x
33+
if Cm is not None:
34+
freq_average_rule = partial(
35+
average_rule,
36+
num_model_size=Cm,
37+
axis=3
38+
)
39+
else:
40+
freq_average_rule = lambda x: x
41+
vis_model_avg = time_average_rule(freq_average_rule(vis_model))
42+
return vis_model_avg
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import numpy as np
2+
3+
from dsa2000_cal.model_utils import average_model
4+
5+
6+
def test_average_model():
7+
D,T,B,C = 3,4,5,6
8+
Tm = 2
9+
Cm = 2
10+
assert np.shape(average_model(np.ones((D,T,B,C)), Tm=Tm, Cm=Cm)) == (D,Tm,B,Cm)
11+
12+
Tm = 1
13+
Cm = 1
14+
assert np.shape(average_model(np.ones((D, T, B, C)), Tm=Tm, Cm=Cm)) == (D, Tm, B, Cm)

0 commit comments

Comments
 (0)