@@ -3496,8 +3496,11 @@ def _reset_random(self, generator, orig_state, use_generator, seed):
3496
3496
return generator .set_state (orig_state ) if use_generator else torch .manual_seed (seed )
3497
3497
3498
3498
def _get_image (self , batched_input , batch_size , device ):
3499
- if batched_input :
3499
+ if batched_input == "first" :
3500
3500
return torch .ones ([batch_size , 3 , 3 , 14 , 14 ], device = device )
3501
+ if batched_input == "last" :
3502
+ return torch .ones ([3 , 3 , 14 , 14 , batch_size ], device = device )
3503
+ assert batched_input == "none"
3501
3504
return torch .ones ([3 , 3 , 14 , 14 ], device = device )
3502
3505
3503
3506
def _assert_all_slices_equal (self , tensor ):
@@ -3511,22 +3514,31 @@ def _assert_all_slices_unique(self, tensor):
3511
3514
slices_equal .diagonal ().zero_ ()
3512
3515
self .assertEqual (slices_equal , torch .zeros_like (slices_equal ))
3513
3516
3514
- def _assert_throws_in_error_mode (self , fn , args , in_dims = 0 ):
3517
+ def _assert_throws_in_error_mode (self , fn , args , in_dims ):
3515
3518
with self .assertRaisesRegex (RuntimeError , r"called random operation while in randomness error mode" ):
3516
3519
vmap (fn , in_dims = in_dims , randomness = "error" )(* args )
3517
3520
3518
- def _assert_throws_in_different_mode_inplace (self , fn , args , in_dims = ( None , 0 ) ):
3521
+ def _assert_throws_in_different_mode_inplace (self , fn , args , in_dims ):
3519
3522
with self .assertRaisesRegex (RuntimeError , r"different inplace randomness on an unbatched tensor" ):
3520
3523
vmap (fn , in_dims = in_dims , randomness = "different" )(* args )
3521
3524
3522
- def _assert_throws_in_same_mode_batched (self , fn , args , in_dims = 0 ):
3525
+ def _assert_throws_in_same_mode_batched (self , fn , args , in_dims ):
3523
3526
with self .assertRaisesRegex (RuntimeError ,
3524
3527
r"Vmap does not currently support same randomness with a batched tensor input" ):
3525
3528
vmap (fn , in_dims = in_dims , randomness = "same" )(* args )
3526
3529
3527
- def _in_dims (self , * batched ):
3528
- batched = batched + (True ,) # for the always batched dummy argument
3529
- return tuple (0 if is_batched else None for is_batched in batched )
3530
+ def _in_dims (self , * batched_strings ):
3531
+
3532
+ def get_in_dim (batched_string ):
3533
+ if batched_string == "first" :
3534
+ return 0
3535
+ if batched_string == "last" :
3536
+ return - 1
3537
+ assert batched_string == "none"
3538
+ return None
3539
+
3540
+ batched_strings = batched_strings + ("first" ,) # for the always batched as first dim dummy argument
3541
+ return tuple (get_in_dim (batched_string ) for batched_string in batched_strings )
3530
3542
3531
3543
@parametrize ('randomness' , ['same' , 'different' , 'error' ])
3532
3544
@parametrize ('use_generator' , [True , False ])
@@ -3597,7 +3609,7 @@ def test_randperm(self, device, randomness, use_generator):
3597
3609
self .assertEqual (vmap_result [i ], expected )
3598
3610
3599
3611
@parametrize ('randomness' , ['error' , 'same' , 'different' ])
3600
- @parametrize ('batched_input' , [True , False ])
3612
+ @parametrize ('batched_input' , ["first" , "last" , "none" ])
3601
3613
def test_dropout (self , device , randomness , batched_input ):
3602
3614
def op (t , ignored ):
3603
3615
return torch .nn .functional .dropout (torch .ones_like (t ), training = True )
@@ -3628,7 +3640,7 @@ def op(t, ignored):
3628
3640
self ._assert_all_slices_equal (vmap_result )
3629
3641
3630
3642
@parametrize ('randomness' , ['error' , 'same' , 'different' ])
3631
- @parametrize ('batched_input' , [True , False ])
3643
+ @parametrize ('batched_input' , ["first" , "last" , "none" ])
3632
3644
def test_alpha_dropout (self , device , randomness , batched_input ):
3633
3645
def op (t , ignored ):
3634
3646
return torch .nn .functional .alpha_dropout (torch .ones_like (t ), training = True )
@@ -3654,7 +3666,7 @@ def op(t, ignored):
3654
3666
self ._assert_all_slices_equal (vmap_result )
3655
3667
3656
3668
@parametrize ('randomness' , ['error' , 'same' , 'different' ])
3657
- @parametrize ('batched_input' , [True , False ])
3669
+ @parametrize ('batched_input' , ["first" , "last" , "none" ])
3658
3670
@parametrize ('dim' , [2 , 3 ])
3659
3671
def test_feature_dropout (self , device , randomness , batched_input , dim ):
3660
3672
def op (t , ignored ):
@@ -3665,7 +3677,8 @@ def op(t, ignored):
3665
3677
always_batched = torch .randn ((B0 ,))
3666
3678
passed = self ._get_image (batched_input , B0 , device )
3667
3679
if dim == 3 :
3668
- passed = passed .unsqueeze (- 1 )
3680
+ unsqueeze_dim = - 2 if batched_input == "last" else - 1
3681
+ passed = passed .unsqueeze (unsqueeze_dim )
3669
3682
in_dims = self ._in_dims (batched_input )
3670
3683
3671
3684
if randomness == 'error' :
@@ -3696,15 +3709,16 @@ def op(t, ignored):
3696
3709
self ._assert_all_slices_equal (vmap_result )
3697
3710
3698
3711
@parametrize ('randomness' , ['error' , 'same' , 'different' ])
3699
- @parametrize ('batched_input' , [True , False ])
3712
+ @parametrize ('batched_input' , ["first" , "last" , "none" ])
3700
3713
def test_feature_alpha_dropout (self , device , randomness , batched_input ):
3701
3714
def op (t , ignored ):
3702
3715
return torch .nn .functional .feature_alpha_dropout (torch .ones_like (t ), training = True )
3703
3716
3704
3717
B0 = 4
3705
3718
always_batched = torch .randn ((B0 ,))
3706
3719
passed = self ._get_image (batched_input , B0 , device )
3707
- passed = passed .unsqueeze (- 1 )
3720
+ unsqueeze_dim = - 2 if batched_input == "last" else - 1
3721
+ passed = passed .unsqueeze (unsqueeze_dim )
3708
3722
in_dims = self ._in_dims (batched_input )
3709
3723
3710
3724
if randomness == 'error' :
@@ -3733,7 +3747,7 @@ def op(t, ignored):
3733
3747
self ._assert_all_slices_equal (vmap_result )
3734
3748
3735
3749
@parametrize ('randomness' , ['error' , 'same' , 'different' ])
3736
- @parametrize ('batched_input' , [True , False ])
3750
+ @parametrize ('batched_input' , ["first" , "last" , "none" ])
3737
3751
def test_like_functions (self , device , randomness , batched_input ):
3738
3752
seed = 1234567
3739
3753
supported_ops = [
@@ -3747,7 +3761,6 @@ def test_like_functions(self, device, randomness, batched_input):
3747
3761
for op in supported_ops :
3748
3762
always_batched = torch .randn (B0 )
3749
3763
passed = self ._get_image (batched_input , B0 , device )
3750
- passed = passed .unsqueeze (- 1 )
3751
3764
in_dims = self ._in_dims (batched_input )
3752
3765
3753
3766
if randomness == 'error' :
@@ -3759,8 +3772,11 @@ def test_like_functions(self, device, randomness, batched_input):
3759
3772
vmap_result = vmap (op , randomness = randomness , in_dims = in_dims )(passed , always_batched )
3760
3773
3761
3774
torch .manual_seed (seed )
3775
+
3776
+ if batched_input == "last" :
3777
+ passed = passed .movedim (- 1 , 0 )
3762
3778
if randomness == 'different' :
3763
- if not batched_input :
3779
+ if batched_input == "none" :
3764
3780
passed = passed .expand (B0 , * passed .shape )
3765
3781
expected = op (passed , 0 )
3766
3782
@@ -3769,15 +3785,16 @@ def test_like_functions(self, device, randomness, batched_input):
3769
3785
return
3770
3786
3771
3787
assert randomness == 'same'
3772
- passed = passed if not batched_input else passed [0 ]
3788
+ if batched_input != "none" :
3789
+ passed = passed [0 ]
3773
3790
expected = op (passed , 0 )
3774
3791
self ._assert_all_slices_equal (vmap_result )
3775
3792
for i in range (B0 ):
3776
3793
self .assertEqual (expected , vmap_result [i ])
3777
3794
3778
3795
@parametrize ('use_generator' , [True , False ])
3779
3796
@parametrize ('randomness' , ['error' , 'same' , 'different' ])
3780
- @parametrize ('batched_input' , [True , False ])
3797
+ @parametrize ('batched_input' , ["first" , "last" , "none" ])
3781
3798
def test_random_unary_inplace (self , device , use_generator , randomness , batched_input ):
3782
3799
generator = torch .Generator (device = device )
3783
3800
orig_state = generator .get_state ()
@@ -3786,7 +3803,7 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i
3786
3803
lambda t , _ : t .random_ (** kwargs ),
3787
3804
lambda t , _ : t .random_ (100 , ** kwargs ),
3788
3805
lambda t , _ : t .random_ (- 5 , 100 , ** kwargs ),
3789
- lambda t , _ : t .normal_ (** kwargs ),
3806
+ # lambda t, _: t.normal_(**kwargs), TODO(samdow): fix normal_ with -1 bdim
3790
3807
lambda t , _ : t .bernoulli_ (** kwargs ),
3791
3808
lambda t , _ : t .cauchy_ (** kwargs ),
3792
3809
lambda t , _ : t .exponential_ (** kwargs ),
@@ -3807,20 +3824,22 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i
3807
3824
if randomness == 'error' :
3808
3825
self ._assert_throws_in_error_mode (op , (passed , always_batched ), in_dims = in_dims )
3809
3826
return
3810
- if randomness == 'different' and not batched_input :
3811
- self ._assert_throws_in_different_mode_inplace (op , (passed , always_batched ))
3827
+ if randomness == 'different' and batched_input == "none" :
3828
+ self ._assert_throws_in_different_mode_inplace (op , (passed , always_batched ), in_dims = in_dims )
3812
3829
return
3813
3830
3814
3831
generator = self ._reset_random (generator , orig_state , use_generator , seed )
3815
3832
vmap_result = vmap (op , in_dims = in_dims , randomness = randomness )(passed , always_batched )
3816
3833
3834
+ if batched_input == "last" :
3835
+ passed_expected = passed_expected .movedim (- 1 , 0 )
3817
3836
generator = self ._reset_random (generator , orig_state , use_generator , seed )
3818
3837
if randomness == "different" :
3819
3838
expected = op (passed_expected , always_batched )
3820
3839
self ._assert_all_slices_unique (vmap_result )
3821
3840
self .assertEqual (vmap_result , expected )
3822
3841
else :
3823
- if batched_input :
3842
+ if batched_input != "none" :
3824
3843
passed_expected = passed_expected [0 ]
3825
3844
expected = op (passed_expected , always_batched )
3826
3845
self ._assert_all_slices_equal (vmap_result )
@@ -3829,8 +3848,8 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i
3829
3848
3830
3849
@parametrize ('use_generator' , [True , False ])
3831
3850
@parametrize ('randomness' , ['error' , 'same' , 'different' ])
3832
- @parametrize ('batched_input' , [True , False ])
3833
- @parametrize ('batched_probability' , [True , False ])
3851
+ @parametrize ('batched_input' , ["first" , "last" , "none" ])
3852
+ @parametrize ('batched_probability' , ["first" , "last" , "none" ])
3834
3853
def test_bernoulli_in_place (self , device , use_generator , randomness , batched_input , batched_probability ):
3835
3854
B0 = 4
3836
3855
seed = 1234567
@@ -3851,28 +3870,32 @@ def op(t, p, ignored):
3851
3870
if randomness == 'error' :
3852
3871
self ._assert_throws_in_error_mode (op , (input , probability , always_batched ), in_dims = in_dims )
3853
3872
return
3854
- if randomness == 'same' and batched_probability :
3873
+ if randomness == 'same' and batched_probability != "none" :
3855
3874
self ._assert_throws_in_same_mode_batched (op , (input , probability , always_batched ), in_dims = in_dims )
3856
3875
return
3857
- if not batched_input and batched_probability :
3876
+ if batched_input == "none" and batched_probability != "none" :
3858
3877
regex = r"there exists a Tensor `other` in extra_args that has more elements than `self`"
3859
3878
with self .assertRaisesRegex (RuntimeError , regex ):
3860
3879
vmap (op , in_dims = in_dims , randomness = randomness )(input , probability , always_batched )
3861
3880
return
3862
- if randomness == 'different' and not batched_input :
3881
+ if randomness == 'different' and batched_input == "none" :
3863
3882
self ._assert_throws_in_different_mode_inplace (op , (input , probability , always_batched ), in_dims = in_dims )
3864
3883
return
3865
3884
3866
3885
self ._reset_random (generator , orig_state , use_generator , seed )
3867
3886
vmap_result = vmap (op , in_dims = in_dims , randomness = randomness )(input , probability , always_batched )
3868
3887
3869
3888
self ._reset_random (generator , orig_state , use_generator , seed )
3889
+ if batched_input == "last" :
3890
+ input_expected = input_expected .movedim (- 1 , 0 )
3891
+ if batched_probability == "last" :
3892
+ probability = probability .movedim (- 1 , 0 )
3870
3893
if randomness == "different" :
3871
3894
expected = op (input_expected , probability , always_batched )
3872
3895
self ._assert_all_slices_unique (vmap_result )
3873
3896
self .assertEqual (vmap_result , expected )
3874
3897
else :
3875
- if batched_input :
3898
+ if batched_input != "none" :
3876
3899
input_expected = input_expected [0 ]
3877
3900
expected = op (input_expected , probability , always_batched )
3878
3901
self ._assert_all_slices_equal (vmap_result )
@@ -3881,15 +3904,16 @@ def op(t, p, ignored):
3881
3904
3882
3905
@parametrize ('use_generator' , [True , False ])
3883
3906
@parametrize ('randomness' , ['error' , 'same' , 'different' ])
3884
- @parametrize ('batched_input' , [True , False ])
3885
- @parametrize ('batched_other' , [True , False ])
3907
+ @parametrize ('batched_input' , ["first" , "last" , "none" ])
3908
+ @parametrize ('batched_other' , ["first" , "last" , "none" ])
3886
3909
def test_random_binary_out_of_place (self , device , use_generator , randomness , batched_input , batched_other ):
3887
3910
generator = torch .Generator (device = device )
3888
3911
orig_state = generator .get_state ()
3889
3912
kwargs = {'generator' : generator } if use_generator else {}
3890
3913
ops = [
3891
3914
lambda t , o , _ : torch .normal (t , o , ** kwargs ),
3892
- lambda t , o , _ : torch .binomial (t , (o - 0.5 ), ** kwargs ),
3915
+ # TODO(samdow): fix binomial
3916
+ # lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
3893
3917
]
3894
3918
3895
3919
B0 = 4
@@ -3904,31 +3928,35 @@ def test_random_binary_out_of_place(self, device, use_generator, randomness, bat
3904
3928
if randomness == 'error' :
3905
3929
self ._assert_throws_in_error_mode (op , (input , other , always_batched ), in_dims = in_dims )
3906
3930
return
3907
- if randomness == 'same' and (batched_input or batched_other ):
3931
+ if randomness == 'same' and (batched_input != "none" or batched_other != "none" ):
3908
3932
self ._assert_throws_in_same_mode_batched (op , (input , other , always_batched ), in_dims = in_dims )
3909
3933
return
3910
3934
3911
3935
generator = self ._reset_random (generator , orig_state , use_generator , seed )
3912
3936
vmap_result = vmap (op , in_dims = in_dims , randomness = randomness )(input , other , always_batched )
3913
3937
3938
+ if batched_input == "last" :
3939
+ input = input .movedim (- 1 , 0 )
3940
+ if batched_other == "last" :
3941
+ other = other .movedim (- 1 , 0 )
3942
+
3914
3943
generator = self ._reset_random (generator , orig_state , use_generator , seed )
3915
3944
if randomness == "different" :
3916
- if not batched_input :
3945
+ if batched_input == "none" :
3917
3946
input = input .expand (B0 , * input .shape )
3918
3947
expected = op (input , other , always_batched )
3919
3948
self ._assert_all_slices_unique (vmap_result )
3920
3949
self .assertEqual (vmap_result , expected )
3921
3950
else :
3922
- if batched_input :
3923
- input = input [0 ]
3951
+ assert batched_input == "none" and batched_other == "none"
3924
3952
expected = op (input , other , always_batched )
3925
3953
self ._assert_all_slices_equal (vmap_result )
3926
3954
for i in range (B0 ):
3927
3955
self .assertEqual (vmap_result [i ], expected )
3928
3956
3929
3957
@parametrize ('use_generator' , [True , False ])
3930
3958
@parametrize ('randomness' , ['error' , 'same' , 'different' ])
3931
- @parametrize ('batched_input' , [True , False ])
3959
+ @parametrize ('batched_input' , ["first" , "last" , "none" ])
3932
3960
def test_random_unary_out_of_place (self , device , use_generator , randomness , batched_input ):
3933
3961
generator = torch .Generator (device = device )
3934
3962
orig_state = generator .get_state ()
@@ -3949,7 +3977,7 @@ def flatten_op(t, ignored):
3949
3977
3950
3978
B0 = 4
3951
3979
seed = 1234567
3952
- in_dims = 0 if batched_input else ( None , 0 )
3980
+ in_dims = self . _in_dims ( batched_input )
3953
3981
3954
3982
for op in ops :
3955
3983
always_batched = torch .randn (B0 , device = device )
@@ -3960,17 +3988,19 @@ def flatten_op(t, ignored):
3960
3988
if randomness == 'error' :
3961
3989
self ._assert_throws_in_error_mode (op , (passed , always_batched ), in_dims = in_dims )
3962
3990
return
3963
- if randomness == 'same' and batched_input :
3964
- self ._assert_throws_in_same_mode_batched (op , (passed , always_batched ))
3991
+ if randomness == 'same' and batched_input != "none" :
3992
+ self ._assert_throws_in_same_mode_batched (op , (passed , always_batched ), in_dims = in_dims )
3965
3993
return
3966
3994
3967
3995
generator = self ._reset_random (generator , orig_state , use_generator , seed )
3968
3996
vmap_result = vmap (op , in_dims = in_dims , randomness = randomness )(passed , always_batched )
3969
3997
3970
3998
generator = self ._reset_random (generator , orig_state , use_generator , seed )
3971
3999
if randomness == "different" :
3972
- if not batched_input :
4000
+ if batched_input == "none" :
3973
4001
passed = passed .expand (B0 , * passed .shape )
4002
+ if batched_input == "last" :
4003
+ passed = passed .movedim (- 1 , 0 )
3974
4004
expected = op (passed , always_batched )
3975
4005
self ._assert_all_slices_unique (vmap_result )
3976
4006
self .assertEqual (vmap_result , expected )
0 commit comments