Commit 1dca562

Balandat authored and facebook-github-bot committed
Properly generate tensors in test_distributions.py (#619)
Summary: Gets rid of a bunch of UserWarnings that were introduced in this test due to improper tensor generation.

Pull Request resolved: #619
Reviewed By: qingfeng10
Differential Revision: D25386701
Pulled By: Balandat
fbshipit-source-id: e766e93ddd58df1c953f9b8016224f0915fbcf9d
1 parent 46dec82 commit 1dca562
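
For context, the UserWarnings being removed are PyTorch's "copy construct" warnings, which fire when torch.tensor() is called on something that is already a tensor. The sketch below is not part of the commit; it is a minimal illustration (assuming a reasonably recent PyTorch, with variable names of my own) of the old pattern that triggers the warning and the warning-free patterns the diff switches to:

import warnings

import torch

# Old pattern: torch.tensor() applied to an existing tensor copy-constructs it,
# which emits "UserWarning: To copy construct from a tensor, it is recommended
# to use sourceTensor.clone().detach() ...".
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True)
assert any("copy construct" in str(w.message) for w in caught)

# New patterns used in the diff: build the tensor directly and flip the
# requires_grad flag in place, or take a detached copy. Nothing is
# copy-constructed, so no warning is raised.
new = torch.randn(2, 3).abs().requires_grad_(True)
detached_copy = torch.randn(2, 3).clone().detach()

requires_grad_(True) works here because torch.randn(...).abs() does not participate in any autograd graph (its input does not require gradients), so it is still a leaf tensor whose flag can be set in place.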

File tree

1 file changed: +9 −9 lines changed

test/distributions/test_distributions.py

Lines changed: 9 additions & 9 deletions
@@ -136,10 +136,10 @@ def assertTensorsEqual(a, b):
 
 class TestKumaraswamy(BotorchTestCase, TestCase):
     def test_kumaraswamy_shape(self):
-        concentration1 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True)
-        concentration0 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True)
-        concentration1_1d = torch.tensor(torch.randn(1).abs(), requires_grad=True)
-        concentration0_1d = torch.tensor(torch.randn(1).abs(), requires_grad=True)
+        concentration1 = torch.randn(2, 3).abs().requires_grad_(True)
+        concentration0 = torch.randn(2, 3).abs().requires_grad_(True)
+        concentration1_1d = torch.randn(1).abs().requires_grad_(True)
+        concentration0_1d = torch.randn(1).abs().requires_grad_(True)
         self.assertEqual(
             Kumaraswamy(concentration1, concentration0).sample().size(), (2, 3)
         )
@@ -159,10 +159,10 @@ def test_kumaraswamy_shape(self):
     # Kumaraswamy distribution is not implemented in SciPy
     # Hence these tests are explicit
     def test_kumaraswamy_mean_variance(self):
-        c1_1 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True)
-        c0_1 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True)
-        c1_2 = torch.tensor(torch.randn(4).abs(), requires_grad=True)
-        c0_2 = torch.tensor(torch.randn(4).abs(), requires_grad=True)
+        c1_1 = torch.randn(2, 3).abs().requires_grad_(True)
+        c0_1 = torch.randn(2, 3).abs().requires_grad_(True)
+        c1_2 = torch.randn(4).abs().requires_grad_(True)
+        c0_2 = torch.randn(4).abs().requires_grad_(True)
         cases = [(c1_1, c0_1), (c1_2, c0_2)]
         for i, (a, b) in enumerate(cases):
             m = Kumaraswamy(a, b)
@@ -472,7 +472,7 @@ def test_cdf_log_prob(self):
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
-                samples = torch.tensor(dist.sample())
+                samples = dist.sample().clone().detach()
                 if samples.dtype.is_floating_point:
                     samples.requires_grad_()
                 try:
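
One detail in the last hunk: the drawn samples are copied with .clone().detach(), and requires_grad_() is only applied afterwards behind the dtype check, because only floating-point tensors can have gradients enabled. A small illustration of that constraint (the Bernoulli and randint examples below are my own, not taken from the test's EXAMPLES):

import torch

samples = torch.distributions.Bernoulli(probs=0.5).sample((3,))
samples = samples.clone().detach()      # detached copy, no copy-construct warning
if samples.dtype.is_floating_point:
    samples.requires_grad_()            # fine: Bernoulli samples are float tensors

int_samples = torch.randint(0, 10, (3,))
# int_samples.requires_grad_() would raise a RuntimeError, since only floating
# point (and complex) tensors can require gradients.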

0 commit comments
