@@ -47,12 +47,40 @@ std::tuple<Tensor,optional<int64_t>> _new_zeros_with_same_feature_meta_batch_rul
47
47
return std::make_tuple (result, 0 );
48
48
}
49
49
50
+ std::tuple<Tensor,optional<int64_t >> randn_like_batch_rule (
51
+ const Tensor& self, optional<int64_t > self_bdim,
52
+ c10::optional<ScalarType> dtype,
53
+ c10::optional<Layout> layout,
54
+ c10::optional<Device> device,
55
+ c10::optional<bool > pin_memory,
56
+ c10::optional<c10::MemoryFormat> optional_memory_format) {
57
+ // Disable the random key
58
+ c10::impl::ExcludeDispatchKeyGuard guard (kVmapModeKey );
59
+ return std::make_tuple (
60
+ at::randn_like (self, dtype, layout, device, pin_memory, optional_memory_format),
61
+ self_bdim);
62
+ }
63
+
64
+ std::tuple<Tensor,optional<int64_t >> rand_like_batch_rule (
65
+ const Tensor& self, optional<int64_t > self_bdim,
66
+ c10::optional<ScalarType> dtype,
67
+ c10::optional<Layout> layout,
68
+ c10::optional<Device> device,
69
+ c10::optional<bool > pin_memory,
70
+ c10::optional<c10::MemoryFormat> optional_memory_format) {
71
+ // Disable the random key
72
+ c10::impl::ExcludeDispatchKeyGuard guard (kVmapModeKey );
73
+ return std::make_tuple (
74
+ at::rand_like (self, dtype, layout, device, pin_memory, optional_memory_format),
75
+ self_bdim);
76
+ }
77
+
50
78
TORCH_LIBRARY_IMPL (aten, FT_BATCHED_KEY, m) {
51
79
VMAP_SUPPORT (" ones_like" , BASIC_UNARY_BATCH_RULE (ATEN_FN (ones_like)));
52
80
VMAP_SUPPORT (" zeros_like" , BASIC_UNARY_BATCH_RULE (ATEN_FN (zeros_like)));
53
81
VMAP_SUPPORT (" empty_like" , BASIC_UNARY_BATCH_RULE (ATEN_FN (empty_like)));
54
- VMAP_SUPPORT (" randn_like" , BASIC_UNARY_BATCH_RULE ( ATEN_FN (randn_like)) );
55
- VMAP_SUPPORT (" rand_like" , BASIC_UNARY_BATCH_RULE ( ATEN_FN (rand_like)) );
82
+ VMAP_SUPPORT (" randn_like" , randn_like_batch_rule );
83
+ VMAP_SUPPORT (" rand_like" , rand_like_batch_rule );
56
84
VMAP_SUPPORT (" full_like" , BASIC_UNARY_BATCH_RULE (ATEN_FN (full_like)));
57
85
VMAP_SUPPORT (" new_empty" , NEW_BLAH_BATCH_RULE (ATEN_FN (new_empty)));
58
86
VMAP_SUPPORT (" new_zeros" , NEW_BLAH_BATCH_RULE (ATEN_FN (new_zeros)));
0 commit comments