
Commit 6887581

r-barnes authored and facebook-github-bot committed
Annotate some functions that return None
Summary: Test functions return None. This codemod fixes that so type annotation efforts can focus on trickier cases.

Reviewed By: azad-meta

Differential Revision: D52570248

fbshipit-source-id: b20e5ec6cde1132d4e1f954af1e012d8464343c8
1 parent da05f77 · commit 6887581
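
The summary calls this a codemod. As a rough sketch of how such a transformation can be written (this uses libcst and is purely illustrative; the commit does not say which tool was actually used), the transformer below adds `-> None` to unannotated test-style methods:

import libcst as cst


class AnnotateNoneReturns(cst.CSTTransformer):
    """Add "-> None" to test functions that lack a return annotation (sketch)."""

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        # Skip functions that already carry a return annotation.
        if updated_node.returns is not None:
            return updated_node
        # Heuristic (an assumption for this sketch): test methods and setUp
        # conventionally return None.
        if not updated_node.name.value.startswith(("test", "setUp")):
            return updated_node
        return updated_node.with_changes(returns=cst.Annotation(cst.Name("None")))


source = "def test_rdp_accountant(self):\n    pass\n"
print(cst.parse_module(source).visit(AnnotateNoneReturns()).code)
# def test_rdp_accountant(self) -> None:
#     pass

Applied across the repository, a codemod of this kind produced the 31 one-line changes shown below.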

7 files changed, 31 insertions(+), 31 deletions(-)


examples/mnist_lightning.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def test_step(self, batch, batch_idx):
         self.log("test_accuracy", self.test_accuracy, on_step=False, on_epoch=True)
         return loss

-    def on_train_epoch_end(self):
+    def on_train_epoch_end(self) -> None:
         # Logging privacy spent: (epsilon, delta)
         epsilon = self.privacy_engine.get_epsilon(self.delta)
         self.log("epsilon", epsilon, on_epoch=True, prog_bar=True)

opacus/tests/accountants_test.py

Lines changed: 10 additions & 10 deletions
@@ -27,7 +27,7 @@


 class AccountingTest(unittest.TestCase):
-    def test_rdp_accountant(self):
+    def test_rdp_accountant(self) -> None:
         noise_multiplier = 1.5
         sample_rate = 0.04
         steps = int(90 / 0.04)
@@ -39,7 +39,7 @@ def test_rdp_accountant(self):
         epsilon = accountant.get_epsilon(delta=1e-5)
         self.assertAlmostEqual(epsilon, 7.32911117143)

-    def test_gdp_accountant(self):
+    def test_gdp_accountant(self) -> None:
         noise_multiplier = 1.5
         sample_rate = 0.04
         steps = int(90 // 0.04)
@@ -52,7 +52,7 @@ def test_gdp_accountant(self):
         self.assertLess(6.59, epsilon)
         self.assertLess(epsilon, 6.6)

-    def test_prv_accountant(self):
+    def test_prv_accountant(self) -> None:
         noise_multiplier = 1.5
         sample_rate = 0.04
         steps = int(90 // 0.04)
@@ -65,7 +65,7 @@ def test_prv_accountant(self):
         epsilon = accountant.get_epsilon(delta=1e-5)
         self.assertAlmostEqual(epsilon, 6.777395712150674)

-    def test_get_noise_multiplier_rdp_epochs(self):
+    def test_get_noise_multiplier_rdp_epochs(self) -> None:
         delta = 1e-5
         sample_rate = 0.04
         epsilon = 8
@@ -81,7 +81,7 @@ def test_get_noise_multiplier_rdp_epochs(self):

         self.assertAlmostEqual(noise_multiplier, 1.416, places=4)

-    def test_get_noise_multiplier_rdp_steps(self):
+    def test_get_noise_multiplier_rdp_steps(self) -> None:
         delta = 1e-5
         sample_rate = 0.04
         epsilon = 8
@@ -96,7 +96,7 @@ def test_get_noise_multiplier_rdp_steps(self):

         self.assertAlmostEqual(noise_multiplier, 1.3562, places=4)

-    def test_get_noise_multiplier_prv_epochs(self):
+    def test_get_noise_multiplier_prv_epochs(self) -> None:
         delta = 1e-5
         sample_rate = 0.04
         epsilon = 8
@@ -112,7 +112,7 @@ def test_get_noise_multiplier_prv_epochs(self):

         self.assertAlmostEqual(noise_multiplier, 1.34765625, places=4)

-    def test_get_noise_multiplier_prv_steps(self):
+    def test_get_noise_multiplier_prv_steps(self) -> None:
         delta = 1e-5
         sample_rate = 0.04
         epsilon = 8
@@ -153,7 +153,7 @@ def test_get_noise_multiplier_overshoot(self, epsilon, epochs, sample_rate, delt
         actual_epsilon = accountant.get_epsilon(delta=delta)
         self.assertLess(actual_epsilon, epsilon)

-    def test_get_noise_multiplier_gdp(self):
+    def test_get_noise_multiplier_gdp(self) -> None:
         delta = 1e-5
         sample_rate = 0.04
         epsilon = 8
@@ -169,7 +169,7 @@ def test_get_noise_multiplier_gdp(self):

         self.assertAlmostEqual(noise_multiplier, 1.3232421875)

-    def test_accountant_state_dict(self):
+    def test_accountant_state_dict(self) -> None:
         noise_multiplier = 1.5
         sample_rate = 0.04
         steps = int(90 / 0.04)
@@ -191,7 +191,7 @@ def test_accountant_state_dict(self):
             accountant.state_dict(dummy_dest)["dummy_k"], dummy_dest["dummy_k"]
         )

-    def test_accountant_load_state_dict(self):
+    def test_accountant_load_state_dict(self) -> None:
         noise_multiplier = 1.5
         sample_rate = 0.04
         steps = int(90 / 0.04)

opacus/tests/batch_memory_manager_test.py

Lines changed: 2 additions & 2 deletions
@@ -171,7 +171,7 @@ def test_empty_batch(
         )
         weights_before = torch.clone(model._module.fc.weight)

-    def test_equivalent_to_one_batch(self):
+    def test_equivalent_to_one_batch(self) -> None:
         torch.manual_seed(1337)
         model, optimizer, data_loader = self._init_training()

@@ -229,7 +229,7 @@ def test_equivalent_to_one_batch(self):
 class BatchMemoryManagerTestWithExpandedWeights(BatchMemoryManagerTest):
     GSM_MODE = "ew"

-    def test_empty_batch(self):
+    def test_empty_batch(self) -> None:
         pass


opacus/tests/ddp_hook_check.py

Lines changed: 4 additions & 4 deletions
@@ -255,15 +255,15 @@ def run_function(local_function, tensor, dp, noise_multiplier=0, max_grad_norm=1


 class GradientComputationTest(unittest.TestCase):
-    def test_connection(self):
+    def test_connection(self) -> None:
         tensor = torch.zeros(10, 10)
         world_size = run_function(debug, tensor, dp=True)

         self.assertTrue(
             world_size >= 2, f"Need at least 2 gpus but was provided only {world_size}."
         )

-    def test_gradient_noclip_zeronoise(self):
+    def test_gradient_noclip_zeronoise(self) -> None:
         # Tests that gradient is the same with DP or with DDP
         weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)

@@ -272,7 +272,7 @@ def test_gradient_noclip_zeronoise(self):

         self.assertTrue(torch.norm(weight_dp - weight_nodp) < 1e-7)

-    def test_ddp_hook(self):
+    def test_ddp_hook(self) -> None:
         # Tests that the DDP hook does the same thing as naive aggregation with per layer clipping
         weight_ddp_naive, weight_ddp_hook = torch.zeros(10, 10), torch.zeros(10, 10)

@@ -297,7 +297,7 @@ def test_ddp_hook(self):
             f"DDP naive: {weight_ddp_naive}\nDDP hook: {weight_ddp_hook}",
         )

-    def test_add_remove_ddp_hooks(self):
+    def test_add_remove_ddp_hooks(self) -> None:
         remaining_hooks = {
             "attached": None,
             "detached": None,

opacus/tests/distributed_poisson_test.py

Lines changed: 3 additions & 3 deletions
@@ -54,13 +54,13 @@ def setUp(self) -> None:

         self.samplers, self.dataloaders = self._init_data(seed=7)

-    def test_length(self):
+    def test_length(self) -> None:
         for sampler in self.samplers:
             self.assertEqual(len(sampler), 10)
         for dataloader in self.dataloaders:
             self.assertEqual(len(dataloader), 10)

-    def test_batch_sizes(self):
+    def test_batch_sizes(self) -> None:
         for dataloader in self.dataloaders:
             batch_sizes = []
             for x, _y in dataloader:
@@ -71,7 +71,7 @@ def test_batch_sizes(self):
                 np.mean(batch_sizes), self.batch_size // self.world_size, delta=2
             )

-    def test_separate_batches(self):
+    def test_separate_batches(self) -> None:
         indices = {
             rank: [i.item() for batch in self.samplers[rank] for i in batch]
             for rank in range(self.world_size)

opacus/tests/dpdataloader_test.py

Lines changed: 4 additions & 4 deletions
@@ -20,12 +20,12 @@


 class DPDataLoaderTest(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.data_size = 10
         self.dimension = 7
         self.num_classes = 11

-    def test_collate_classes(self):
+    def test_collate_classes(self) -> None:
         x = torch.randn(self.data_size, self.dimension)
         y = torch.randint(low=0, high=self.num_classes, size=(self.data_size,))

@@ -36,7 +36,7 @@ def test_collate_classes(self):
         self.assertEqual(x_b.size(0), 0)
         self.assertEqual(y_b.size(0), 0)

-    def test_collate_tensor(self):
+    def test_collate_tensor(self) -> None:
         x = torch.randn(self.data_size, self.dimension)

         dataset = TensorDataset(x)
@@ -46,7 +46,7 @@ def test_collate_tensor(self):

         self.assertEqual(s.size(0), 0)

-    def test_drop_last_true(self):
+    def test_drop_last_true(self) -> None:
         x = torch.randn(self.data_size, self.dimension)

         dataset = TensorDataset(x)

opacus/tests/grad_sample_module_test.py

Lines changed: 7 additions & 7 deletions
@@ -225,26 +225,26 @@ def forward(self, x: torch.Tensor):
         register_grad_sampler(SimpleLinear)(compute_linear_grad_sample)
         GradSampleModule(SimpleLinear(4, 2))

-    def test_custom_module_validation(self):
+    def test_custom_module_validation(self) -> None:
         with self.assertRaises(NotImplementedError):
             GradSampleModule(mobilenet_v3_small())

-    def test_submodule_access(self):
+    def test_submodule_access(self) -> None:
         _ = self.grad_sample_module.fc1
         _ = self.grad_sample_module.fc2

         with self.assertRaises(AttributeError):
             _ = self.grad_sample_module.fc3

-    def test_state_dict(self):
+    def test_state_dict(self) -> None:
         gs_state_dict = self.grad_sample_module.state_dict()
         og_state_dict = self.original_model.state_dict()
         # check wrapped module state dict
         for key in og_state_dict.keys():
             self.assertTrue(f"_module.{key}" in gs_state_dict)
             assert_close(og_state_dict[key], gs_state_dict[f"_module.{key}"])

-    def test_load_state_dict(self):
+    def test_load_state_dict(self) -> None:
         gs_state_dict = self.grad_sample_module.state_dict()
         new_gs = GradSampleModule(
             SampleConvNet(), batch_first=False, loss_reduction="mean"
@@ -261,11 +261,11 @@ def test_load_state_dict(self):
 class EWGradSampleModuleTest(GradSampleModuleTest):
     CLS = GradSampleModuleExpandedWeights

-    def test_remove_hooks(self):
+    def test_remove_hooks(self) -> None:
         pass

-    def test_enable_hooks(self):
+    def test_enable_hooks(self) -> None:
         pass

-    def test_disable_hooks(self):
+    def test_disable_hooks(self) -> None:
         pass
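
As a brief illustration of what these annotations buy (this example is hypothetical, not part of the commit): once a test is annotated `-> None`, a type checker such as mypy will flag any accidental use of its return value.

def test_length() -> None:
    assert len([1, 2, 3]) == 3

result = test_length()  # mypy: error: "test_length" does not return a value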
