Conversation

@ricardoV94 (Member)
Related to #1806 #1827

- Fix bug when passing simple Tensor shape to split_dims
- Change grad_undefined -> grad_disconnected for split_sizes in SplitOp (see #1827 for more context)

```python
):
    # All elements already have the right number of dimensions, so we
    # can just join them directly.
    return join(0, *x)
```
@ricardoV94 (Member Author)

Isn't this equivalent to stack below?

@jessegrabowski (Member) · Jan 9, 2026

No, because stack adds a dimension. This was causing a bug in split_dims where we explicitly ask for ndim=1, passing a sequence of 1d tensors, but then get back a 2d tensor.
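A quick illustration of the difference (a sketch using pytensor.tensor directly; the split_dims internals aren't shown in this thread):

```python
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")

pt.stack([a, b]).ndim  # 2: stack adds a new leading axis
pt.join(0, a, b).ndim  # 1: join concatenates along an existing axis
```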

@ricardoV94 (Member Author)

Hmm, so my understanding is that this function is supposed to do what np.array(x) would do. I think the ndim is more of an assert: it should fail when the output of np.array (in our case the symbolic equivalent) would yield something different. In that sense join is never valid, as it keeps the same number of dimensions.
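For example, with plain numpy (a sketch of the semantics being described):

```python
import numpy as np

# np.array nests a sequence of 1d arrays into a 2d array, like stack does
np.array([np.ones(3), np.zeros(3)]).shape  # (2, 3)

# so under "ndim as an assert" semantics, requesting ndim=1 for this input
# should raise, rather than silently join the pieces into one 1d tensor
```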

I want to revert and check if I'm missing something with the test that was failing.

@jessegrabowski (Member)

Sure. From my perspective the biggest issue is that as_tensor_variable(..., ndim=1) isn't idempotent -- sequential calls on the same input keep adding nodes on top of the same graph. This is happening because of stack.
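A minimal repro sketch of the non-idempotency being described (assuming the keyword is spelled ndim, as in PyTensor's as_tensor_variable signature):

```python
import pytensor.tensor as pt

x = pt.vector("x")
y1 = pt.as_tensor_variable(x, ndim=1)
y2 = pt.as_tensor_variable(y1, ndim=1)

# Idempotency would require y2 to be the same variable (or at least the
# same graph) as y1; the reported bug is that each call stacks new nodes
# on top of the previous result.
assert y1 is y2  # fails under the buggy behavior described above
```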

@ricardoV94 (Member Author) · Jan 9, 2026

That's odd, because if it's already a single tensor variable (and not a list containing one) it shouldn't do anything.

@ricardoV94 (Member Author) · Jan 9, 2026

Yeah that first one seems wrong.

Even if we fix it, I think our check for "sequence" in split_dims (or wherever the problem was) should be more like: if isinstance(x, Sequence) or (isinstance(x, TensorVariable) and x.ndim == 1).

1d numpy arrays should also be valid, though those don't pass the Sequence instance check, since np.ndarray doesn't register as a collections.abc.Sequence.
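Something like this, as a sketch (the helper name and its exact placement in split_dims are hypothetical):

```python
from collections.abc import Sequence

import numpy as np
from pytensor.tensor.variable import TensorVariable

def _is_shape_sequence(x):
    # Treat as a "sequence of sizes": real sequences, 1d symbolic tensors,
    # and 1d ndarrays (which do not register as Sequence instances).
    return (
        isinstance(x, Sequence)
        or (isinstance(x, TensorVariable) and x.ndim == 1)
        or (isinstance(x, np.ndarray) and x.ndim == 1)
    )
```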

@ricardoV94 (Member Author)

Maybe we should remove the ndim argument altogether? numpy doesn't have it and I don't think we need it.

I thought it was just used for validation, but it seems to affect non-raising outcomes.

@jessegrabowski (Member)

I'm +1 for removing it. I never knew it existed, and it seems like it's overloading the function.

If I had to guess though, it's exactly for this situation. We have an argument with type int | Variable | tuple[int | Variable]. The Variable, though, can be either a scalar or an array. So really the typing is something like int | Variable[ndim=0] | Variable[ndim=1] | tuple[int | Variable[ndim=0]. When we do the if not isinstance(shape, tuple): shape = (shape, ) we're ignoring the Variable[ndim=1] case. Calling as_tensor_variable(tuple[Variable[ndim=0]) -> Variable[ndim=1] makes sense to me, and matches the numpy behavior. In this case we're counting on the ndim=1 arugment to guard against the case of as_tensor_variable(tuple[Variable[ndim=1]) -> Variable[ndim=2].

Typing all this out, it seems like an abuse of the as_tensor_variable function.
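Spelled out, the normalization being described might look like this (normalize_shape is a hypothetical helper for illustration, not the actual split_dims code):

```python
import pytensor.tensor as pt
from pytensor.tensor.variable import TensorVariable

def normalize_shape(shape):
    # Accept int | Variable[ndim=0] | Variable[ndim=1] | tuple[int | Variable[ndim=0]]
    # and always return a 1d vector of sizes.
    if isinstance(shape, TensorVariable) and shape.ndim == 1:
        # Already a 1d shape vector: wrapping it in a tuple and stacking
        # would wrongly produce a 2d result.
        return shape
    if not isinstance(shape, tuple):
        shape = (shape,)
    return pt.stack([pt.as_tensor_variable(s) for s in shape])
```

Handling the ndim=1 case explicitly up front is what removes the need for an ndim guard inside as_tensor_variable itself.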

@ricardoV94 (Member Author) · Jan 9, 2026

Yeah, agreed. It would be really nice to have those TensorVariable[ndim=0] types, btw. Need to nerd-snipe some type-hint lovers.
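For the record, a rough sketch of what such a parametrized type could look like (entirely hypothetical; nothing like this exists in PyTensor, and current type checkers can't enforce the ndim parameter without a plugin):

```python
from typing import Generic, Literal, TypeVar

Ndim = TypeVar("Ndim", bound=int)

class TypedTensorVariable(Generic[Ndim]):
    """Hypothetical TensorVariable parametrized by its number of dimensions."""
    ndim: int

# Annotations would then read as documentation, e.g.:
# def split_dims(x, shape: TypedTensorVariable[Literal[1]]) -> ...
```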
