
Commit f99ed9e (1 parent: 8b47f4c)

add atleast_nd decompositions + tests

File tree: 3 files changed (+31, -0 lines)


functorch/csrc/BatchRulesStopDecomposition.cpp renamed to functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 6 additions & 0 deletions

@@ -25,6 +25,12 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE(arcsinh);
   OP_DECOMPOSE(arctan);
   OP_DECOMPOSE(arctanh);
+  OP_DECOMPOSE(atleast_1d);
+  OP_DECOMPOSE2(atleast_1d, Sequence);
+  OP_DECOMPOSE(atleast_2d);
+  OP_DECOMPOSE2(atleast_2d, Sequence);
+  OP_DECOMPOSE(atleast_3d);
+  OP_DECOMPOSE2(atleast_3d, Sequence);
   OP_DECOMPOSE(broadcast_tensors);
   OP_DECOMPOSE(broadcast_to);
   OP_DECOMPOSE(clip);
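
With these entries registered, vmap needs no dedicated batching rules for the atleast_* ops: each call is re-expressed in terms of ops that already have rules (the OP_DECOMPOSE2 lines cover the Sequence overloads, which accept a list of tensors). A minimal sketch of what this enables, assuming the functorch-era `from functorch import vmap` API; the output shapes follow from the atleast_* semantics:

import torch
from functorch import vmap

x = torch.randn(3)                      # vmap peels dim 0, so each sample is a 0-dim scalar
print(vmap(torch.atleast_1d)(x).shape)  # torch.Size([3, 1]): each scalar becomes shape (1,)

y = torch.randn(3, 4)                   # each sample is a 1-dim vector of length 4
print(vmap(torch.atleast_2d)(y).shape)  # torch.Size([3, 1, 4]): each vector becomes shape (1, 4)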

test/functorch_additional_op_db.py

Lines changed: 24 additions & 0 deletions

@@ -215,3 +215,27 @@ def sample_inputs_cross_entropy(self, device, dtype, requires_grad, reduction):
            supports_out=True))
 
 
+
+def sample_inputs_atleast_nd(self, device, dtype, requires_grad):
+    inps = []
+    for i in range(5):
+        inps.append(make_tensor(list(range(i)), device=device, dtype=dtype,
+                                requires_grad=requires_grad, low=-1, high=1))
+
+    sample_inputs = []
+    for inp in inps:
+        sample_inputs.append(SampleInput(inp))
+
+    sample_inputs.append(SampleInput(inps))
+    return sample_inputs
+
+for i in range(1, 4):
+    additional_op_db.append(
+        OpInfo(f'atleast_{i}d',
+               aten_name=f"atleast_{i}d",
+               supports_autograd=True,
+               sample_inputs_func=sample_inputs_atleast_nd,
+               dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+               supports_out=False))
+
+
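For reference, a standalone illustration (hypothetical, outside the OpInfo harness; in the test suite, make_tensor and SampleInput come from torch's testing utilities) of the shapes this sampler generates: one tensor per rank 0 through 4, plus a final sample that passes the whole list to exercise the Sequence overload:

import torch

shapes = [list(range(i)) for i in range(5)]
print(shapes)  # [[], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]]

for s in shapes:
    t = torch.empty(s)
    # atleast_2d prepends size-1 dims to inputs of rank < 2 and leaves the rest unchanged
    print(tuple(t.shape), '->', tuple(torch.atleast_2d(t).shape))
# () -> (1, 1)
# (0,) -> (1, 0)
# higher-rank shapes pass through unchanged
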

test/test_vmap.py

Lines changed: 1 addition & 0 deletions

@@ -2951,6 +2951,7 @@ def test_diagonal(self, device):
         x = torch.randn(3, 4, 5, device=device, requires_grad=True)
         self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,))
 
+
     @allowVmapFallbackUsage
     def test_unrelated_output(self, device):
         B0 = 3
