
Commit a8ef423

jbschlosser authored and pytorchmergebot committed
Fix NJT min / max backward() for non-ragged reductions (pytorch#144583)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests. `value_selecting_reduction_backward()` is used in the backward for min / max, so this PR implements it for NJT. Notably, this isn't enough for reducing over the ragged dim, since that results in a dense tensor and thus NJT's torch_dispatch will not be called for this op. We need factory function support for nested ints to fix that case.

Pull Request resolved: pytorch#144583
Approved by: https://github.com/soulitzer
ghstack dependencies: pytorch#144582
1 parent cac10b8 commit a8ef423
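
A minimal sketch of the case this change addresses (shapes are illustrative and not taken from the PR; assumes a PyTorch build with jagged-layout NJT support):

import torch

# NJT of logical shape (B=2, j1, D=8); dim 1 is the ragged dim.
njt = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8)],
    layout=torch.jagged,
    requires_grad=True,
)

# Reducing over a non-ragged dim keeps the output jagged, so NJT's
# torch_dispatch sees value_selecting_reduction_backward during backward
# and the new handler in torch/nested/_internal/ops.py runs.
max_vals, max_idx = njt.max(dim=-1)
max_vals.values().sum().backward()  # previously failed for non-ragged reductions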

File tree: 4 files changed, +54 -4 lines changed


aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
@@ -3866,6 +3866,7 @@
   device_guard: False
   dispatch:
     CompositeImplicitAutograd: value_selecting_reduction_backward_symint
+    NestedTensorCPU, NestedTensorCUDA: value_selecting_reduction_backward_nested_symint
 
 - func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
   variants: function, method

aten/src/ATen/native/nested/NestedTensorUtils.cpp

Lines changed: 12 additions & 0 deletions
@@ -10,6 +10,7 @@
 #include <ATen/ops/_nested_tensor_strides_native.h>
 #include <ATen/ops/chunk_native.h>
 #include <ATen/ops/split_with_sizes_native.h>
+#include <ATen/ops/value_selecting_reduction_backward_native.h>
 #endif
 
 namespace at::native {

@@ -166,4 +167,15 @@ std::vector<Tensor> split_with_sizes_nested(
   return splits;
 }
 
+Tensor value_selecting_reduction_backward_nested_symint(
+    const Tensor& grad,
+    int64_t dim,
+    const Tensor& indices,
+    c10::SymIntArrayRef sizes,
+    bool keepdim) {
+  TORCH_INTERNAL_ASSERT(
+      false, "value_selecting_reduction_backward(): expected to be implemented in Python"
+  );
+}
+
 } // namespace at::native
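
The stub above is presumably there so the op has real NestedTensor kernels and reaches NJT's __torch_dispatch__, where the Python handler added in torch/nested/_internal/ops.py runs; per its own message, the C++ body should never execute. One hedged way to inspect the registration (torch._C._dispatch_dump is an internal debugging helper and may change between releases):

import torch

# Dumps the kernels registered per dispatch key for the op; with this change,
# NestedTensorCPU / NestedTensorCUDA entries should appear alongside
# CompositeImplicitAutograd.
print(torch._C._dispatch_dump("aten::value_selecting_reduction_backward"))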

test/test_nestedtensor.py

Lines changed: 5 additions & 4 deletions
@@ -8317,15 +8317,16 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
         op_match_fn=lambda device, op: (op.full_name == "narrow"),
         name="broken_narrow_backward",
     ),
-    # min / max: need to examine backwards formula for non-full reduction
+    # min / max: need factory function support for ragged dim reductions
+    # where the output is dense but sizes still contain a nested int
     XFailRule(
         error_type=RuntimeError,
         error_msg="SymIntArrayRef expected to contain only concrete integers",
         op_match_fn=lambda device, op: (
             op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"}
         ),
-        sample_match_fn=lambda device, sample: ("full reduction" not in sample.name),
-        name="broken_min_max_reduction_with_dim_backward",
+        sample_match_fn=lambda device, sample: ("ragged dim" in sample.name),
+        name="broken_min_max_reduction_with_dim_backward_on_ragged_dim",
     ),
     # matmul(): unimplemented backward
     XFailRule(

@@ -8496,7 +8497,7 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
         op_match_fn=lambda device, op: (
             op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"}
         ),
-        sample_match_fn=lambda device, sample: ("full reduction" not in sample.name),
+        sample_match_fn=lambda device, sample: ("ragged dim" in sample.name),
        name="broken_min_max_compile_backward",
     ),
     # to() fails with data-dependent guards OR Unknown layout in record_stream_any_impl;
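
For reference, a hedged sketch of the still-xfailed ragged-dim case targeted by the rules above (shapes are illustrative; the quoted error text comes from the rule's error_msg):

import torch

njt = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8)],
    layout=torch.jagged,
    requires_grad=True,
)

# Reducing over the ragged dim (dim 1) produces a dense (B, D) tensor, so NJT's
# torch_dispatch never sees value_selecting_reduction_backward; backward then
# needs to build a tensor from sizes that still contain a nested int and is
# expected to raise "SymIntArrayRef expected to contain only concrete integers".
vals, idx = njt.max(dim=1)
vals.sum().backward()  # expected to fail until nested-int factory support lands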

torch/nested/_internal/ops.py

Lines changed: 36 additions & 0 deletions
@@ -2014,6 +2014,42 @@ def argmax_default(func, *args, **kwargs):
     return _apply_reduction(func, "argmax", dtype_min, *args, **kwargs)
 
 
+@register_jagged_func(
+    torch.ops.aten.value_selecting_reduction_backward.default,
+    "grad: jt_all, dim: any, indices: jt_all, sizes: any, keepdim: any",
+)
+def value_selecting_reduction_backward_default(func, *args, **kwargs):
+    from torch.fx.experimental.symbolic_shapes import is_nested_int
+
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    grad = new_kwargs.pop("grad")
+    new_kwargs["grad"] = grad._values
+    indices = new_kwargs.pop("indices")
+    new_kwargs["indices"] = indices._values
+    # should always succeed; sizes should contain a nested int
+    ragged_idx = next(i for i, s in enumerate(new_kwargs["sizes"]) if is_nested_int(s))
+    # convert dim -> values-space dim
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(new_kwargs["sizes"]),
+        new_kwargs["dim"],
+        ragged_idx,
+        "value_selecting_reduction_backward",
+    )
+    # convert saved NJT sizes -> values-space sizes
+    sizes = new_kwargs.pop("sizes")
+    sizes[ragged_idx] = indices._values.size(indices._ragged_idx - 1)
+    sizes = sizes[1:]
+    new_kwargs["sizes"] = sizes
+
+    output_kwargs = extract_kwargs(indices)
+    output_kwargs["_ragged_idx"] = ragged_idx
+
+    return NestedTensor(func(**new_kwargs), **output_kwargs)
+
+
 @register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
 def stack_default(func, *args, **kwargs):
     _, new_kwargs = normalize_function(  # type: ignore[misc]
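
As a plain-Python illustration of the sizes conversion in the handler above (hypothetical numbers, no torch required): for an NJT of logical shape (B=2, j1, D=8) built from components with 3 and 5 rows, the packed values buffer has shape (8, 8), and the saved NJT sizes map to that buffer's sizes as follows.

# Saved sizes from the forward input, with the nested int shown symbolically.
saved_sizes = [2, "j1", 8]
ragged_idx = 1                 # position of the nested int in saved_sizes
total_length = 3 + 5           # what indices._values.size(indices._ragged_idx - 1) yields

values_space_sizes = list(saved_sizes)
values_space_sizes[ragged_idx] = total_length  # nested int -> packed ragged length
values_space_sizes = values_space_sizes[1:]    # drop the batch dim
print(values_space_sizes)                      # [8, 8] == shape of the values buffer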
