Skip to content

Commit 10e1c9d

Browse files
Author: Samantha Andow
Commit message: fix multinomial (#664)
1 parent: 015f8b6 — commit: 10e1c9d

File tree

2 files changed

+100
-8
lines changed

2 files changed

+100
-8
lines changed

functorch/csrc/BatchRulesRandomness.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,39 @@ std::tuple<Tensor,Tensor> native_dropout_batching_rule(const Tensor& tensor, dou
207207
return std::make_tuple(output, mask);
208208
}
209209

210+
// Batching rule for aten::multinomial under vmap.
//
// Unwraps the batched `self`, moves the vmap batch dim to the front, and —
// because multinomial itself only accepts 1-D or 2-D inputs — flattens a
// batched 2-D input ([B, num_dists, num_cats]) down to 2-D before calling
// through, then splits the batch dim back out of the result.
Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_samples, const bool replacement, const c10::optional<Generator> generator) {
  c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey);
  auto maybe_layer = maybeCurrentDynamicLayer();
  const auto cur_level = maybe_layer->layerId();

  Tensor self_value;
  optional<int64_t> self_bdim;
  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
  self_value = moveBatchDimToFront(self_value, self_bdim);

  RandomnessType randomness = maybe_layer->randomness();
  check_randomness(randomness, self_bdim.has_value());

  if (randomness == RandomnessType::Different && !self_bdim) {
    // "Different" randomness with an unbatched input: materialize a leading
    // batch dim so each vmap slice gets its own independent draws.
    auto shape = self_value.sizes();
    VmapDimVector shapeVec(1, maybe_layer->batchSize());
    shapeVec.reserve(shape.size() + 1);
    shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());
    self_value = self_value.expand(shapeVec);
  }

  // Record whether we flatten BEFORE reshaping. The original code re-tested
  // `self_value.dim() == 3` after `reshape_dim_into` had already made
  // self_value 2-D, so its reshape-back branch was unreachable dead code.
  const bool flattened_batch_dim =
      self_value.dim() == 3 && (self_bdim || randomness == RandomnessType::Different);
  if (flattened_batch_dim) {
    self_value = reshape_dim_into(1, 0, self_value);
  }
  auto out = multinomial(self_value, num_samples, replacement, generator);
  if (randomness == RandomnessType::Same && !self_bdim) {
    return out;
  }
  if (flattened_batch_dim) {
    // out is [B * num_dists, num_samples]; split the leading dim back out.
    // Note the last dim of out is num_samples, NOT the category count, so
    // reshaping to self.sizes() (as the original did) produced the wrong
    // shape whenever num_samples != num_cats.
    out = out.reshape({maybe_layer->batchSize(), -1, num_samples});
  }
  return makeBatched(out, 0, cur_level);
}
242+
210243
template <typename A, A a, typename C>
211244
struct RandomBatchRuleHelper;
212245

@@ -420,7 +453,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
420453

421454
UNARY_POINTWISE_RANDOM(_standard_gamma);
422455
UNARY_POINTWISE_RANDOM(_sample_dirichlet);
423-
UNARY_POINTWISE_RANDOM(multinomial);
456+
m.impl("multinomial", multinomial_batching_rule);
424457
UNARY_POINTWISE_RANDOM(poisson);
425458
UNARY_POINTWISE_RANDOM(bernoulli);
426459

test/test_vmap.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3970,20 +3970,13 @@ def test_random_unary_out_of_place(self, device, use_generator, randomness, batc
39703970
lambda t, _: torch.poisson(t, **kwargs),
39713971
]
39723972

3973-
# TODO(samdow): fix multinomial and readd
3974-
def flatten_op(t, ignored):
3975-
return torch.multinomial(t, 10, **kwargs)
3976-
39773973
B0 = 4
39783974
seed = 1234567
39793975
in_dims = self._in_dims(batched_input)
39803976

39813977
for op in ops:
39823978
always_batched = torch.randn(B0, device=device)
39833979
passed = self._get_image(batched_input, B0, device)
3984-
if op == flatten_op:
3985-
passed = passed.flatten(1, -1)
3986-
39873980
if randomness == 'error':
39883981
self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims)
39893982
return
@@ -4009,6 +4002,72 @@ def flatten_op(t, ignored):
40094002
for i in range(B0):
40104003
self.assertEqual(vmap_result[i], expected)
40114004

4005+
@parametrize('use_generator', [True, False])
@parametrize('randomness', ['error', 'same', 'different'])
@parametrize('batched_call', [True, False])
@parametrize('batched_input', ["first", "last", "none"])
def test_multinomial(self, device, use_generator, randomness, batched_call, batched_input):
    """Checks vmap over torch.multinomial for every combination of randomness
    mode, generator use, batched/unbatched call, and batch-dim placement."""

    def flatten_input(input, batch_call, batch_location):
        # multinomial accepts at most 2-D input, so flatten the trailing
        # image dims, keeping (optionally) a vmap batch dim and a
        # distributions dim.
        if batch_call and batch_location != "none":
            final_size = 3  # [B0, B, N]
        elif not batch_call and batch_location == "none":
            final_size = 1  # [N]
        else:
            final_size = 2  # [B0, N] or [B, N]

        start_idx = final_size - 1
        end_idx = -1
        if batch_location == "last":
            start_idx -= 1
            end_idx -= 1  # gets to correct final size because using negative indices

        ret = input.flatten(start_idx, end_idx)
        assert ret.dim() == final_size
        return ret

    def op(input, _):
        return torch.multinomial(input, 10, **kwargs)

    generator = torch.Generator(device=device)
    orig_state = generator.get_state()
    kwargs = {'generator': generator} if use_generator else {}

    B0 = 4
    seed = 1234567
    in_dims = self._in_dims(batched_input)

    always_batched = torch.randn(B0, device=device)
    passed = self._get_image(batched_input, B0, device)
    passed = flatten_input(passed, batched_call, batched_input)
    if randomness == 'error':
        self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims)
        return
    if randomness == 'same' and batched_input != "none":
        self._assert_throws_in_same_mode_batched(op, (passed, always_batched), in_dims=in_dims)
        return

    generator = self._reset_random(generator, orig_state, use_generator, seed)
    vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched)

    # Reset RNG so the unbatched reference call sees the same stream.
    generator = self._reset_random(generator, orig_state, use_generator, seed)

    if randomness == "different":
        if batched_input == "none":
            passed = passed.expand(B0, *passed.shape)
        if batched_input == "last":
            passed = passed.movedim(-1, 0)
        orig_passed_size = passed.shape[:2] if batched_call else passed.shape[:1]
        passed = passed.flatten(0, 1) if batched_call else passed
        expected = op(passed, always_batched)
        # Tensor.reshape returns a NEW tensor; the original code discarded
        # the result, leaving `expected` with its flattened shape.
        expected = expected.reshape(*orig_passed_size, 10)
        self._assert_all_slices_unique(vmap_result)
        self.assertEqual(vmap_result, expected)
    else:
        expected = op(passed, always_batched)
        self._assert_all_slices_equal(vmap_result)
        for i in range(B0):
            self.assertEqual(vmap_result[i], expected)
4070+
40124071
def test_unsupported_random(self, device):
40134072
x = torch.randn(3, device=device)
40144073
y = x.abs()

0 commit comments

Comments
 (0)