@@ -3970,20 +3970,13 @@ def test_random_unary_out_of_place(self, device, use_generator, randomness, batc
3970
3970
lambda t , _ : torch .poisson (t , ** kwargs ),
3971
3971
]
3972
3972
3973
- # TODO(samdow): fix multinomial and readd
3974
- def flatten_op (t , ignored ):
3975
- return torch .multinomial (t , 10 , ** kwargs )
3976
-
3977
3973
B0 = 4
3978
3974
seed = 1234567
3979
3975
in_dims = self ._in_dims (batched_input )
3980
3976
3981
3977
for op in ops :
3982
3978
always_batched = torch .randn (B0 , device = device )
3983
3979
passed = self ._get_image (batched_input , B0 , device )
3984
- if op == flatten_op :
3985
- passed = passed .flatten (1 , - 1 )
3986
-
3987
3980
if randomness == 'error' :
3988
3981
self ._assert_throws_in_error_mode (op , (passed , always_batched ), in_dims = in_dims )
3989
3982
return
@@ -4009,6 +4002,72 @@ def flatten_op(t, ignored):
4009
4002
for i in range (B0 ):
4010
4003
self .assertEqual (vmap_result [i ], expected )
4011
4004
4005
+ @parametrize ('use_generator' , [True , False ])
4006
+ @parametrize ('randomness' , ['error' , 'same' , 'different' ])
4007
+ @parametrize ('batched_call' , [True , False ])
4008
+ @parametrize ('batched_input' , ["first" , "last" , "none" ])
4009
+ def test_multinomial (self , device , use_generator , randomness , batched_call , batched_input ):
4010
+ def flatten_input (input , batch_call , batch_location ):
4011
+ if batch_call and batch_location != "none" :
4012
+ final_size = 3 # [B0, B, N]
4013
+ elif not batch_call and batch_location == "none" :
4014
+ final_size = 1 # [N]
4015
+ else :
4016
+ final_size = 2 # [B0, N] or [B, N]
4017
+
4018
+ start_idx = final_size - 1
4019
+ end_idx = - 1
4020
+ if batch_location == "last" :
4021
+ start_idx -= 1
4022
+ end_idx -= 1 # gets to correct final size because using negative indices
4023
+
4024
+ ret = input .flatten (start_idx , end_idx )
4025
+ assert ret .dim () == final_size
4026
+ return ret
4027
+
4028
+ def op (input , _ ):
4029
+ return torch .multinomial (input , 10 , ** kwargs )
4030
+
4031
+ generator = torch .Generator (device = device )
4032
+ orig_state = generator .get_state ()
4033
+ kwargs = {'generator' : generator } if use_generator else {}
4034
+
4035
+ B0 = 4
4036
+ seed = 1234567
4037
+ in_dims = self ._in_dims (batched_input )
4038
+
4039
+ always_batched = torch .randn (B0 , device = device )
4040
+ passed = self ._get_image (batched_input , B0 , device )
4041
+ passed = flatten_input (passed , batched_call , batched_input )
4042
+ if randomness == 'error' :
4043
+ self ._assert_throws_in_error_mode (op , (passed , always_batched ), in_dims = in_dims )
4044
+ return
4045
+ if randomness == 'same' and batched_input != "none" :
4046
+ self ._assert_throws_in_same_mode_batched (op , (passed , always_batched ), in_dims = in_dims )
4047
+ return
4048
+
4049
+ generator = self ._reset_random (generator , orig_state , use_generator , seed )
4050
+ vmap_result = vmap (op , in_dims = in_dims , randomness = randomness )(passed , always_batched )
4051
+
4052
+ generator = self ._reset_random (generator , orig_state , use_generator , seed )
4053
+
4054
+ if randomness == "different" :
4055
+ if batched_input == "none" :
4056
+ passed = passed .expand (B0 , * passed .shape )
4057
+ if batched_input == "last" :
4058
+ passed = passed .movedim (- 1 , 0 )
4059
+ orig_passed_size = passed .shape [:2 ] if batched_call else passed .shape [:1 ]
4060
+ passed = passed .flatten (0 , 1 ) if batched_call else passed
4061
+ expected = op (passed , always_batched )
4062
+ expected .reshape (* orig_passed_size , 10 )
4063
+ self ._assert_all_slices_unique (vmap_result )
4064
+ self .assertEqual (vmap_result , expected )
4065
+ else :
4066
+ expected = op (passed , always_batched )
4067
+ self ._assert_all_slices_equal (vmap_result )
4068
+ for i in range (B0 ):
4069
+ self .assertEqual (vmap_result [i ], expected )
4070
+
4012
4071
def test_unsupported_random (self , device ):
4013
4072
x = torch .randn (3 , device = device )
4014
4073
y = x .abs ()
0 commit comments