@@ -132,49 +132,6 @@ bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
132
132
return true ;
133
133
}
134
134
135
- Tensor expand_batching_rule (const Tensor& self, IntArrayRef size, bool implicit) {
136
- if (!participatesInCurrentLevel (self)) {
137
- c10::impl::ExcludeDispatchKeyGuard guard (kBatchedKey );
138
- return self.expand (size, implicit);
139
- }
140
-
141
- auto self_physical = MultiBatchVmapTransform::logicalToPhysical (self);
142
- auto size_physical = self_physical.getPhysicalShape (size);
143
- auto self_physical_dim = self_physical.tensor ().dim ();
144
-
145
- TORCH_CHECK ((uint64_t )self_physical_dim <= size_physical.size (),
146
- " expand: the number of sizes provided (" , /* logical*/ size.size (), " ) " ,
147
- " must be greater or equal to the number of dimensions in the tensor (" ,
148
- /* logical dim*/ self.dim (), " )" );
149
-
150
- if ((uint64_t )self_physical_dim == size_physical.size ()) {
151
- auto result = self_physical.tensor ().expand (size_physical, implicit);
152
- return self_physical.getPhysicalToLogicalMap ().apply (result);
153
- }
154
-
155
- TORCH_INTERNAL_ASSERT ((uint64_t )self_physical_dim < size_physical.size ());
156
- // Here, we know we are expanding a (logical) tensor to a larger number
157
- // of dimensions. We have to be careful because we can't call expand directly
158
- // due to the presence of batch dimensions.
159
- //
160
- // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
161
- // The result should be a tensor of size [B0, 2, 3].
162
- // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
163
- // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
164
- // then expand.
165
- auto self_physical_size = self_physical.tensor ().sizes ();
166
- auto extra_dims = size_physical.size () - self_physical_dim;
167
- VmapDimVector view_shape (size_physical.size (), 1 );
168
- std::copy (self_physical_size.begin (),
169
- self_physical_size.begin () + self_physical.numBatchDims (),
170
- view_shape.begin ());
171
- std::copy (self_physical_size.begin () + self_physical.numBatchDims (),
172
- self_physical_size.end (),
173
- view_shape.begin () + self_physical.numBatchDims () + extra_dims);
174
- auto result = self_physical.tensor ().view (view_shape).expand (size_physical, implicit);
175
- return self_physical.getPhysicalToLogicalMap ().apply (result);
176
- }
177
-
178
135
std::vector<Tensor> chunk_batching_rule (const Tensor& self, int64_t chunks, int64_t dim) {
179
136
if (!participatesInCurrentLevel (self)) {
180
137
c10::impl::ExcludeDispatchKeyGuard guard (kBatchedKey );
@@ -1001,7 +958,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
1001
958
// m.impl("chunk", chunk_batching_rule);
1002
959
m.impl (" tensor_split.sections" , tensor_split_sections_batching_rule);
1003
960
m.impl (" tensor_split.indices" , tensor_split_indices_batching_rule);
1004
- m.impl (" expand" , expand_batching_rule);
1005
961
m.impl (" movedim.intlist" , movedim_batching_rule);
1006
962
m.impl (" movedim.int" , static_cast <Tensor (*)(const Tensor&,int64_t ,int64_t )>(native::movedim)); // composite wrt autograd
1007
963
// NB: static_cast because there's another variant of narrow. However, we don't
0 commit comments