3030torch .set_default_dtype (torch .float64 )
3131
3232D = 3
33- BATCH_SIZE = 16384
33+ SMALL_BATCH_SIZE = 16
34+ LARGE_BATCH_SIZE = 16384
3435REPS = 3
3536ALPHA = 0.001
3637
3738
3839class TestBrownianTree (TorchTestCase ):
3940
40- def _setUp (self , device = None ):
41+ def _setUp (self , batch_size , device = None ):
4142 t0 , t1 = torch .tensor ([0. , 1. ]).to (device )
42- w0 = torch .zeros (BATCH_SIZE , D ).to (device = device )
43- w1 = torch .randn (BATCH_SIZE , D ).to (device = device )
43+ w0 = torch .zeros (batch_size , D ).to (device = device )
44+ w1 = torch .randn (batch_size , D ).to (device = device )
4445 t = torch .rand ([]).to (device )
4546
4647 self .t = t
4748 self .bm = BrownianTree (t0 = t0 , t1 = t1 , w0 = w0 , w1 = w1 , entropy = 0 )
4849
4950 def test_basic_cpu (self ):
50- self ._setUp (device = torch .device ('cpu' ))
51+ self ._setUp (batch_size = SMALL_BATCH_SIZE , device = torch .device ('cpu' ))
5152 sample = self .bm (self .t )
52- self .assertEqual (sample .size (), (BATCH_SIZE , D ))
53+ self .assertEqual (sample .size (), (SMALL_BATCH_SIZE , D ))
5354
5455 def test_basic_gpu (self ):
5556 if not torch .cuda .is_available ():
5657 self .skipTest (reason = 'CUDA not available.' )
5758
58- self ._setUp (device = torch .device ('cuda' ))
59+ self ._setUp (batch_size = SMALL_BATCH_SIZE , device = torch .device ('cuda' ))
5960 sample = self .bm (self .t )
60- self .assertEqual (sample .size (), (BATCH_SIZE , D ))
61+ self .assertEqual (sample .size (), (SMALL_BATCH_SIZE , D ))
6162
6263 def test_determinism (self ):
63- self ._setUp ()
64+ self ._setUp (batch_size = SMALL_BATCH_SIZE )
6465 vals = [self .bm (self .t ) for _ in range (REPS )]
6566 for val in vals [1 :]:
6667 self .tensorAssertAllClose (val , vals [0 ])
@@ -73,8 +74,8 @@ def test_normality(self):
7374 for _ in range (REPS ):
7475 w0_ , w1_ = 0.0 , npr .randn ()
7576 # Use the same endpoint for the batch, so samples from same dist.
76- w0 = torch .tensor (w0_ ).repeat (BATCH_SIZE )
77- w1 = torch .tensor (w1_ ).repeat (BATCH_SIZE )
77+ w0 = torch .tensor (w0_ ).repeat (LARGE_BATCH_SIZE )
78+ w1 = torch .tensor (w1_ ).repeat (LARGE_BATCH_SIZE )
7879 bm = BrownianTree (t0 = t0 , t1 = t1 , w0 = w0 , w1 = w1 , pool_size = 100 , tol = 1e-14 )
7980
8081 for _ in range (REPS ):
@@ -89,6 +90,21 @@ def test_normality(self):
8990 _ , pval = kstest (samples_ , ref_dist .cdf )
9091 self .assertGreaterEqual (pval , ALPHA )
9192
93+ def test_to (self ):
94+ if not torch .cuda .is_available ():
95+ self .skipTest (reason = 'CUDA not available.' )
96+
97+ self ._setUp (batch_size = SMALL_BATCH_SIZE )
98+ cache = self .bm .get_cache ()
99+ old = torch .cat (list (cache ['ws_prev' ]) + list (cache ['ws' ]) + list (cache ['ws_post' ]), dim = 0 )
100+
101+ gpu = torch .device ('cuda' )
102+ self .bm .to (gpu )
103+ cache = self .bm .get_cache ()
104+ new = torch .cat (list (cache ['ws_prev' ]) + list (cache ['ws' ]) + list (cache ['ws_post' ]), dim = 0 )
105+ self .assertTrue (str (new .device ).startswith ('cuda' ))
106+ self .tensorAssertAllClose (old , new .cpu ())
107+
92108
93109if __name__ == '__main__' :
94110 unittest .main ()
0 commit comments