@@ -73,11 +73,11 @@ def setUp(self):
         torch.manual_seed(42)

     @abc.abstractmethod
-    def _init_data(self):
+    def _init_data(self) -> None:
         pass

     @abc.abstractmethod
-    def _init_model(self):
+    def _init_model(self) -> None:
         pass

     def _init_vanilla_training(
@@ -193,7 +193,7 @@ def closure():
             if max_steps and steps >= max_steps:
                 break

-    def test_basic(self):
+    def test_basic(self) -> None:
         for opt_exclude_frozen in [True, False]:
             with self.subTest(opt_exclude_frozen=opt_exclude_frozen):
                 model, optimizer, dl, _ = self._init_private_training(
@@ -287,7 +287,7 @@ def test_compare_to_vanilla(
                 max_steps=max_steps,
             )

-    def test_flat_clipping(self):
+    def test_flat_clipping(self) -> None:
         self.BATCH_SIZE = 1
         max_grad_norm = 0.5

@@ -314,7 +314,7 @@ def test_flat_clipping(self):
         self.assertAlmostEqual(clipped_grads.norm().item(), max_grad_norm, places=3)
         self.assertGreater(non_clipped_grads.norm(), clipped_grads.norm())

-    def test_per_layer_clipping(self):
+    def test_per_layer_clipping(self) -> None:
         self.BATCH_SIZE = 1
         max_grad_norm_per_layer = 1.0

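The two assertions in the hunk above pin down the flat-clipping invariant: with BATCH_SIZE = 1 and a per-sample gradient norm above the threshold, the clipped gradient's norm lands at max_grad_norm. A minimal standalone sketch of that arithmetic (plain tensor math, not Opacus internals):

import torch

max_grad_norm = 0.5
grad = torch.randn(100) * 3.0  # a gradient whose norm is far above 0.5
# Flat clipping scales the whole gradient by min(1, C / norm), projecting
# over-threshold gradients onto the ball of radius C = max_grad_norm.
clip_factor = min(1.0, max_grad_norm / (grad.norm().item() + 1e-6))
clipped = grad * clip_factor
assert abs(clipped.norm().item() - max_grad_norm) < 1e-3
assert grad.norm() > clipped.norm()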
@@ -344,7 +344,7 @@ def test_per_layer_clipping(self):
             min(non_clipped_norm, max_grad_norm_per_layer), clipped_norm, places=3
         )

-    def test_sample_grad_aggregation(self):
+    def test_sample_grad_aggregation(self) -> None:
         """
         Check if final gradient is indeed an aggregation over per-sample gradients
         """
@@ -367,7 +367,7 @@ def test_sample_grad_aggregation(self):
                 f"Param: {p_name}",
             )

-    def test_noise_changes_every_time(self):
+    def test_noise_changes_every_time(self) -> None:
         """
         Test that adding noise results in ever different model params.
         We disable clipping in this test by setting it to a very high threshold.
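The test reduces to the fact that fresh Gaussian noise is drawn on every step, so two runs started from identical weights diverge. A trivial illustration of why two noisy runs cannot end with equal parameters:

import torch

# Two independent Gaussian draws are almost surely different, hence the
# noisy parameter trajectories of two runs never coincide exactly.
noise_a = torch.normal(mean=0.0, std=1.0, size=(10,))
noise_b = torch.normal(mean=0.0, std=1.0, size=(10,))
assert not torch.allclose(noise_a, noise_b)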
@@ -387,7 +387,7 @@ def test_noise_changes_every_time(self):
         for p0, p1 in zip(first_run_params, second_run_params):
             self.assertFalse(torch.allclose(p0, p1))

-    def test_get_compatible_module_inaction(self):
+    def test_get_compatible_module_inaction(self) -> None:
         needs_no_replacement_module = nn.Linear(1, 2)
         fixed_module = PrivacyEngine.get_compatible_module(needs_no_replacement_module)
         self.assertFalse(fixed_module is needs_no_replacement_module)
@@ -397,7 +397,7 @@ def test_get_compatible_module_inaction(self):
             )
         )

-    def test_model_validator(self):
+    def test_model_validator(self) -> None:
         """
         Test that the privacy engine raises errors
         if there are unsupported modules
@@ -416,7 +416,7 @@ def test_model_validator(self):
             grad_sample_mode=self.GRAD_SAMPLE_MODE,
         )

-    def test_model_validator_after_fix(self):
+    def test_model_validator_after_fix(self) -> None:
         """
         Test that the privacy engine fixes unsupported modules
         and succeeds.
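Both validator tests exercise the ModuleValidator flow: validation flags unsupported layers, and fix() swaps them for DP-compatible equivalents. A hedged sketch of that pattern (the module choice is illustrative):

import torch.nn as nn
from opacus.validators import ModuleValidator

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
# With strict=False, validate() returns the violations instead of raising.
errors = ModuleValidator.validate(model, strict=False)
assert len(errors) > 0  # BatchNorm mixes samples, so it is unsupported

fixed = ModuleValidator.fix(model)  # e.g. BatchNorm2d becomes GroupNorm
assert ModuleValidator.validate(fixed, strict=False) == []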
@@ -435,7 +435,7 @@ def test_model_validator_after_fix(self):
         )
         self.assertTrue(1, 1)

-    def test_make_private_with_epsilon(self):
+    def test_make_private_with_epsilon(self) -> None:
         model, optimizer, dl = self._init_vanilla_training()
         target_eps = 2.0
         target_delta = 1e-5
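For context, make_private_with_epsilon targets a fixed (epsilon, delta) budget and lets Opacus derive the noise multiplier. A self-contained sketch of a typical invocation; the model, optimizer, loader, epochs, and max_grad_norm here are illustrative stand-ins, not the test's fixtures:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dl = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=8)

privacy_engine = PrivacyEngine()
model, optimizer, dl = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=dl,
    target_epsilon=2.0,  # target_eps in the test above
    target_delta=1e-5,   # target_delta in the test above
    epochs=10,           # illustrative
    max_grad_norm=1.0,   # illustrative
)
# After training, privacy_engine.get_epsilon(1e-5) should come out
# close to the requested budget of 2.0.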
@@ -458,7 +458,7 @@ def test_make_private_with_epsilon(self):
             target_eps, privacy_engine.get_epsilon(target_delta), places=2
         )

-    def test_deterministic_run(self):
+    def test_deterministic_run(self) -> None:
         """
         Tests that for 2 different models, secure seed can be fixed
         to produce same (deterministic) runs.
@@ -483,7 +483,7 @@ def test_deterministic_run(self):
             "Model parameters after deterministic run must match",
         )

-    def test_validator_weight_update_check(self):
+    def test_validator_weight_update_check(self) -> None:
         """
         Test that the privacy engine raises error if ModuleValidator.fix(model) is
         called after the optimizer is created
@@ -522,7 +522,7 @@ def test_validator_weight_update_check(self):
             grad_sample_mode=self.GRAD_SAMPLE_MODE,
         )

-    def test_parameters_match(self):
+    def test_parameters_match(self) -> None:
         dl = self._init_data()

         m1 = self._init_model()
@@ -721,7 +721,7 @@ def helper_test_noise_level(

     @unittest.skip("requires torchcsprng compatible with new pytorch versions")
     @patch("torch.normal", MagicMock(return_value=torch.Tensor([0.6])))
-    def test_generate_noise_in_secure_mode(self):
+    def test_generate_noise_in_secure_mode(self) -> None:
         """
         Tests that the noise is added correctly in secure_mode,
         according to section 5.1 in https://arxiv.org/abs/2107.10138.
@@ -803,16 +803,16 @@ def _init_model(self):


 class PrivacyEngineConvNetEmptyBatchTest(PrivacyEngineConvNetTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()

         # This will trigger multiple empty batches with poisson sampling enabled
         self.BATCH_SIZE = 1

-    def test_checkpoints(self):
+    def test_checkpoints(self) -> None:
         pass

-    def test_noise_level(self):
+    def test_noise_level(self) -> None:
         pass

@@ -837,23 +837,23 @@ def _init_model(self):


 class PrivacyEngineConvNetFrozenTestFunctorch(PrivacyEngineConvNetFrozenTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.GRAD_SAMPLE_MODE = "functorch"


 class PrivacyEngineConvNetTestExpandedWeights(PrivacyEngineConvNetTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.GRAD_SAMPLE_MODE = "ew"

     @unittest.skip("Original p.grad is not available in ExpandedWeights")
-    def test_sample_grad_aggregation(self):
+    def test_sample_grad_aggregation(self) -> None:
         pass


 class PrivacyEngineConvNetTestFunctorch(PrivacyEngineConvNetTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.GRAD_SAMPLE_MODE = "functorch"
@@ -938,7 +938,7 @@ def _init_model(


 class PrivacyEngineTextTestFunctorch(PrivacyEngineTextTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.GRAD_SAMPLE_MODE = "functorch"
@@ -987,7 +987,7 @@ def _init_model(self):


 class PrivacyEngineTiedWeightsTestFunctorch(PrivacyEngineTiedWeightsTest):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.GRAD_SAMPLE_MODE = "functorch"