@@ -47,40 +47,12 @@ std::tuple<Tensor,optional<int64_t>> _new_zeros_with_same_feature_meta_batch_rul
47
47
return std::make_tuple (result, 0 );
48
48
}
49
49
50
- std::tuple<Tensor,optional<int64_t >> randn_like_batch_rule (
51
- const Tensor& self, optional<int64_t > self_bdim,
52
- c10::optional<ScalarType> dtype,
53
- c10::optional<Layout> layout,
54
- c10::optional<Device> device,
55
- c10::optional<bool > pin_memory,
56
- c10::optional<c10::MemoryFormat> optional_memory_format) {
57
- // Disable the random key
58
- c10::impl::ExcludeDispatchKeyGuard guard (kVmapModeKey );
59
- return std::make_tuple (
60
- at::randn_like (self, dtype, layout, device, pin_memory, optional_memory_format),
61
- self_bdim);
62
- }
63
-
64
- std::tuple<Tensor,optional<int64_t >> rand_like_batch_rule (
65
- const Tensor& self, optional<int64_t > self_bdim,
66
- c10::optional<ScalarType> dtype,
67
- c10::optional<Layout> layout,
68
- c10::optional<Device> device,
69
- c10::optional<bool > pin_memory,
70
- c10::optional<c10::MemoryFormat> optional_memory_format) {
71
- // Disable the random key
72
- c10::impl::ExcludeDispatchKeyGuard guard (kVmapModeKey );
73
- return std::make_tuple (
74
- at::rand_like (self, dtype, layout, device, pin_memory, optional_memory_format),
75
- self_bdim);
76
- }
77
-
78
50
TORCH_LIBRARY_IMPL (aten, FT_BATCHED_KEY, m) {
79
51
VMAP_SUPPORT (" ones_like" , BASIC_UNARY_BATCH_RULE (ATEN_FN (ones_like)));
80
52
VMAP_SUPPORT (" zeros_like" , BASIC_UNARY_BATCH_RULE (ATEN_FN (zeros_like)));
81
53
VMAP_SUPPORT (" empty_like" , BASIC_UNARY_BATCH_RULE (ATEN_FN (empty_like)));
82
- VMAP_SUPPORT (" randn_like" , randn_like_batch_rule );
83
- VMAP_SUPPORT (" rand_like" , rand_like_batch_rule );
54
+ VMAP_SUPPORT (" randn_like" , BASIC_UNARY_BATCH_RULE ( ATEN_FN (randn_like)) );
55
+ VMAP_SUPPORT (" rand_like" , BASIC_UNARY_BATCH_RULE ( ATEN_FN (rand_like)) );
84
56
VMAP_SUPPORT (" full_like" , BASIC_UNARY_BATCH_RULE (ATEN_FN (full_like)));
85
57
VMAP_SUPPORT (" new_empty" , NEW_BLAH_BATCH_RULE (ATEN_FN (new_empty)));
86
58
VMAP_SUPPORT (" new_zeros" , NEW_BLAH_BATCH_RULE (ATEN_FN (new_zeros)));
0 commit comments