Skip to content

Commit 7c0d901

Browse files
Updates
1 parent 9c15478 commit 7c0d901

File tree

2 files changed

+1
-32
lines changed

2 files changed

+1
-32
lines changed

test/test_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4641,4 +4641,4 @@ def test_index_select_stacked_not_supported(self):
46414641

46424642
if __name__ == "__main__":
46434643
args, unknown = argparse.ArgumentParser().parse_known_args()
4644-
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
4644+
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/data/tensor_specs.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6826,37 +6826,6 @@ def _stack_specs(list_of_spec, dim=0, out=None):
68266826
else:
68276827
raise NotImplementedError
68286828

6829-
@TensorSpec.implements_for_spec(torch.index_select)
6830-
@Composite.implements_for_spec(torch.index_select)
6831-
def _index_select_spec(input: TensorSpec, dim: int, index: torch.Tensor) -> TensorSpec:
6832-
dim = dim % len(input.shape) if dim < 0 else dim
6833-
6834-
# Validate index bounds
6835-
if torch.any(index < 0) or torch.any(index >= input.shape[dim]):
6836-
raise IndexError(f"index {index} is out of bounds for dimension {dim} with size {input.shape[dim]}")
6837-
6838-
if isinstance(input, Composite):
6839-
new_shape = list(input.shape)
6840-
new_shape[dim] = index.numel()
6841-
new_specs = {}
6842-
for key, spec in input.items():
6843-
if spec is not None:
6844-
new_specs[key] = torch.index_select(spec, dim, index)
6845-
else:
6846-
new_specs[key] = None
6847-
return Composite(new_specs, shape=torch.Size(new_shape), device=input.device)
6848-
else:
6849-
new_shape = list(input.shape)
6850-
new_shape[dim] = index.numel()
6851-
if isinstance(input, (OneHot, MultiOneHot, Binary)) and dim == len(input.shape) - 1:
6852-
raise ValueError(f"Cannot index_select along the last dimension of {type(input).__name__} spec, as it represents the domain.")
6853-
if isinstance(input, Bounded):
6854-
new_low = torch.index_select(input.space.low, dim, index)
6855-
new_high = torch.index_select(input.space.high, dim, index)
6856-
return input.__class__(low=new_low, high=new_high, shape=torch.Size(new_shape), device=input.device, dtype=input.dtype)
6857-
else:
6858-
return input._reshape(torch.Size(new_shape))
6859-
68606829
@TensorSpec.implements_for_spec(torch.index_select)
68616830
@Composite.implements_for_spec(torch.index_select)
68626831
def _index_select_spec(input: TensorSpec, dim: int, index: torch.Tensor) -> TensorSpec:

0 commit comments

Comments
 (0)