@@ -456,6 +456,20 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis, p):
456
456
assert mask_specgrams .size () == specgrams .size ()
457
457
assert (num_masked_columns < mask_param ).sum () == num_masked_columns .numel ()
458
458
459
+ @parameterized .expand (list (itertools .product ([100 ], [0.0 , 30.0 ], [2 , 3 ], [0.2 , 1.0 ])))
460
+ def test_mask_along_axis_iid_mask_value (self , mask_param , mask_value , axis , p ):
461
+ specgrams = torch .randn (4 , 2 , 1025 , 400 , dtype = self .dtype , device = self .device )
462
+ mask_value_tensor = torch .tensor (mask_value , dtype = self .dtype , device = self .device )
463
+ torch .manual_seed (0 )
464
+ # as this operation is random we need to fix the seed for results to match
465
+ mask_specgrams = F .mask_along_axis_iid (specgrams , mask_param , mask_value_tensor , axis , p = p )
466
+ torch .manual_seed (0 )
467
+ mask_specgrams_float = F .mask_along_axis_iid (specgrams , mask_param , mask_value , axis , p = p )
468
+ assert torch .allclose (
469
+ mask_specgrams , mask_specgrams_float
470
+ ), f"""Masking with float and tensor should be the same diff = {
471
+ torch .abs (mask_specgrams - mask_specgrams_float ).max ()} """
472
+
459
473
@parameterized .expand (list (itertools .product ([(2 , 1025 , 400 ), (1 , 201 , 100 )], [100 ], [0.0 , 30.0 ], [1 , 2 ])))
460
474
def test_mask_along_axis_preserve (self , shape , mask_param , mask_value , axis ):
461
475
"""mask_along_axis should not alter original input Tensor
0 commit comments