Skip to content

Commit 0a0f548

Browse files
committed
Fix to device in brownian utils.
1 parent 495f0ef commit 0a0f548

File tree

4 files changed

+51
-22
lines changed

4 files changed

+51
-22
lines changed

tests/test_brownian_tree.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,37 +30,38 @@
3030
torch.set_default_dtype(torch.float64)
3131

3232
D = 3
33-
BATCH_SIZE = 16384
33+
SMALL_BATCH_SIZE = 16
34+
LARGE_BATCH_SIZE = 16384
3435
REPS = 3
3536
ALPHA = 0.001
3637

3738

3839
class TestBrownianTree(TorchTestCase):
3940

40-
def _setUp(self, device=None):
41+
def _setUp(self, batch_size, device=None):
4142
t0, t1 = torch.tensor([0., 1.]).to(device)
42-
w0 = torch.zeros(BATCH_SIZE, D).to(device=device)
43-
w1 = torch.randn(BATCH_SIZE, D).to(device=device)
43+
w0 = torch.zeros(batch_size, D).to(device=device)
44+
w1 = torch.randn(batch_size, D).to(device=device)
4445
t = torch.rand([]).to(device)
4546

4647
self.t = t
4748
self.bm = BrownianTree(t0=t0, t1=t1, w0=w0, w1=w1, entropy=0)
4849

4950
def test_basic_cpu(self):
50-
self._setUp(device=torch.device('cpu'))
51+
self._setUp(batch_size=SMALL_BATCH_SIZE, device=torch.device('cpu'))
5152
sample = self.bm(self.t)
52-
self.assertEqual(sample.size(), (BATCH_SIZE, D))
53+
self.assertEqual(sample.size(), (SMALL_BATCH_SIZE, D))
5354

5455
def test_basic_gpu(self):
5556
if not torch.cuda.is_available():
5657
self.skipTest(reason='CUDA not available.')
5758

58-
self._setUp(device=torch.device('cuda'))
59+
self._setUp(batch_size=SMALL_BATCH_SIZE, device=torch.device('cuda'))
5960
sample = self.bm(self.t)
60-
self.assertEqual(sample.size(), (BATCH_SIZE, D))
61+
self.assertEqual(sample.size(), (SMALL_BATCH_SIZE, D))
6162

6263
def test_determinism(self):
63-
self._setUp()
64+
self._setUp(batch_size=SMALL_BATCH_SIZE)
6465
vals = [self.bm(self.t) for _ in range(REPS)]
6566
for val in vals[1:]:
6667
self.tensorAssertAllClose(val, vals[0])
@@ -73,8 +74,8 @@ def test_normality(self):
7374
for _ in range(REPS):
7475
w0_, w1_ = 0.0, npr.randn()
7576
# Use the same endpoint for the batch, so samples from same dist.
76-
w0 = torch.tensor(w0_).repeat(BATCH_SIZE)
77-
w1 = torch.tensor(w1_).repeat(BATCH_SIZE)
77+
w0 = torch.tensor(w0_).repeat(LARGE_BATCH_SIZE)
78+
w1 = torch.tensor(w1_).repeat(LARGE_BATCH_SIZE)
7879
bm = BrownianTree(t0=t0, t1=t1, w0=w0, w1=w1, pool_size=100, tol=1e-14)
7980

8081
for _ in range(REPS):
@@ -89,6 +90,21 @@ def test_normality(self):
8990
_, pval = kstest(samples_, ref_dist.cdf)
9091
self.assertGreaterEqual(pval, ALPHA)
9192

93+
def test_to(self):
94+
if not torch.cuda.is_available():
95+
self.skipTest(reason='CUDA not available.')
96+
97+
self._setUp(batch_size=SMALL_BATCH_SIZE)
98+
cache = self.bm.get_cache()
99+
old = torch.cat(list(cache['ws_prev']) + list(cache['ws']) + list(cache['ws_post']), dim=0)
100+
101+
gpu = torch.device('cuda')
102+
self.bm.to(gpu)
103+
cache = self.bm.get_cache()
104+
new = torch.cat(list(cache['ws_prev']) + list(cache['ws']) + list(cache['ws_post']), dim=0)
105+
self.assertTrue(str(new.device).startswith('cuda'))
106+
self.tensorAssertAllClose(old, new.cpu())
107+
92108

93109
if __name__ == '__main__':
94110
unittest.main()

torchsde/brownian/brownian_path.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,7 @@ def __repr__(self):
134134
)
135135

136136
def to(self, *args, **kwargs):
137-
ws_new = blist.blist()
138-
for w in self._ws:
139-
ws_new.append(w.to(*args, **kwargs))
140-
self._ws = ws_new
137+
self._ws = utils.blist_to(self._ws, *args, **kwargs)
141138

142139
@property
143140
def dtype(self):
@@ -153,3 +150,9 @@ def size(self):
153150

154151
def __len__(self):
155152
return len(self._ts)
153+
154+
def get_cache(self):
155+
return {
156+
'ts': self._ts,
157+
'ws': self._ws,
158+
}

torchsde/brownian/brownian_tree.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ def __repr__(self):
138138
)
139139

140140
def to(self, *args, **kwargs):
141-
self._ws_prev = _list_to(self._ws_prev, *args, **kwargs)
142-
self._ws_post = _list_to(self._ws_post, *args, **kwargs)
143-
self._ws = _list_to(self._ws, *args, **kwargs)
141+
self._ws_prev = utils.blist_to(self._ws_prev, *args, **kwargs)
142+
self._ws_post = utils.blist_to(self._ws_post, *args, **kwargs)
143+
self._ws = utils.blist_to(self._ws, *args, **kwargs)
144144

145145
@property
146146
def dtype(self):
@@ -157,6 +157,16 @@ def size(self):
157157
def __len__(self):
158158
return len(self._ts) + len(self._ts_prev) + len(self._ts_post)
159159

160+
def get_cache(self):
161+
return {
162+
'ts_prev': self._ts_prev,
163+
'ts': self._ts,
164+
'ts_post': self._ts_post,
165+
'ws_prev': self._ws_prev,
166+
'ws': self._ws,
167+
'ws_post': self._ws_post
168+
}
169+
160170

161171
def _binary_search(t0, t1, w0, w1, t, parent, tol):
162172
seedv, seedl, seedr = parent.spawn(3)
@@ -211,7 +221,3 @@ def _create_cache(t0, t1, w0, w1, entropy, pool_size, k):
211221
seeds = new_seeds
212222

213223
return ts, ws, seeds
214-
215-
216-
def _list_to(l, *args, **kwargs):
217-
return [li.to(*args, **kwargs) for li in l]

torchsde/brownian/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,7 @@ def brownian_bridge(t0: float, t1: float, w0, w1, t: float, seed=None):
9494

9595
def is_scalar(x):
9696
return isinstance(x, int) or isinstance(x, float) or (isinstance(x, torch.Tensor) and x.numel() == 1)
97+
98+
99+
def blist_to(l, *args, **kwargs):
100+
return blist.blist([li.to(*args, **kwargs) for li in l])

0 commit comments

Comments (0)