Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4122,6 +4122,154 @@
return out


def vecdot(
x1: TensorLike,
x2: TensorLike,
dtype: Optional["DTypeLike"] = None,
) -> TensorVariable:
"""Compute the vector dot product of two arrays.
Parameters
----------
x1, x2
Input arrays with the same shape.
dtype
The desired data-type for the result. If not given, then the type will
be determined as the minimum type required to hold the objects in the
sequence.
Returns
-------
TensorVariable
The vector dot product of the inputs.
Notes
-----
This is equivalent to `numpy.vecdot` and computes the dot product of
vectors along the last axis of both inputs. Broadcasting is supported
across all other dimensions.
Examples
--------
>>> import pytensor.tensor as pt
>>> # Vector dot product with shape (5,) inputs
>>> x = pt.vector("x", shape=(5,)) # shape (5,)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It really has no sense of brevity, the comment is completely superfluous. Not blocking on this, but I doubt any human would be this "meh"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, claude 3.7 is known for being verbose and over-eager.

>>> y = pt.vector("y", shape=(5,)) # shape (5,)
>>> z = pt.vecdot(x, y) # scalar output
>>> # Equivalent to numpy.vecdot(x, y)
>>>
>>> # With batched inputs of shape (3, 5)
>>> x_batch = pt.matrix("x", shape=(3, 5)) # shape (3, 5)
>>> y_batch = pt.matrix("y", shape=(3, 5)) # shape (3, 5)
>>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,)
>>> # Equivalent to numpy.vecdot(x_batch, y_batch)
"""
out = _inner_prod(x1, x2)

if dtype is not None:
out = out.astype(dtype)

return out


def matvec(
x1: TensorLike, x2: TensorLike, dtype: Optional["DTypeLike"] = None
) -> TensorVariable:
"""Compute the matrix-vector product.
Parameters
----------
x1
Input array for the matrix with shape (..., M, K).
x2
Input array for the vector with shape (..., K).
dtype
The desired data-type for the result. If not given, then the type will
be determined as the minimum type required to hold the objects in the
sequence.
Returns
-------
TensorVariable
The matrix-vector product with shape (..., M).
Notes
-----
This is equivalent to `numpy.matvec` and computes the matrix-vector product
with broadcasting over batch dimensions.
Examples
--------
>>> import pytensor.tensor as pt
>>> # Matrix-vector product
>>> A = pt.matrix("A", shape=(3, 4)) # shape (3, 4)
>>> v = pt.vector("v", shape=(4,)) # shape (4,)
>>> result = pt.matvec(A, v) # shape (3,)
>>> # Equivalent to numpy.matvec(A, v)
>>>
>>> # Batched matrix-vector product
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) # shape (2, 3, 4)
>>> batched_v = pt.matrix("v", shape=(2, 4)) # shape (2, 4)
>>> result = pt.matvec(batched_A, batched_v) # shape (2, 3)
>>> # Equivalent to numpy.matvec(batched_A, batched_v)
"""
out = _matrix_vec_prod(x1, x2)

if dtype is not None:
out = out.astype(dtype)

Check warning on line 4219 in pytensor/tensor/math.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/math.py#L4219

Added line #L4219 was not covered by tests

return out


def vecmat(
x1: TensorLike, x2: TensorLike, dtype: Optional["DTypeLike"] = None
) -> TensorVariable:
"""Compute the vector-matrix product.
Parameters
----------
x1
Input array for the vector with shape (..., K).
x2
Input array for the matrix with shape (..., K, N).
dtype
The desired data-type for the result. If not given, then the type will
be determined as the minimum type required to hold the objects in the
sequence.
Returns
-------
TensorVariable
The vector-matrix product with shape (..., N).
Notes
-----
This is equivalent to `numpy.vecmat` and computes the vector-matrix product
with broadcasting over batch dimensions.
Examples
--------
>>> import pytensor.tensor as pt
>>> # Vector-matrix product
>>> v = pt.vector("v", shape=(3,)) # shape (3,)
>>> A = pt.matrix("A", shape=(3, 4)) # shape (3, 4)
>>> result = pt.vecmat(v, A) # shape (4,)
>>> # Equivalent to numpy.vecmat(v, A)
>>>
>>> # Batched vector-matrix product
>>> batched_v = pt.matrix("v", shape=(2, 3)) # shape (2, 3)
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) # shape (2, 3, 4)
>>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4)
>>> # Equivalent to numpy.vecmat(batched_v, batched_A)
"""
out = _vec_matrix_prod(x1, x2)

if dtype is not None:
out = out.astype(dtype)

Check warning on line 4268 in pytensor/tensor/math.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/math.py#L4268

Added line #L4268 was not covered by tests

return out


@_vectorize_node.register(Dot)
def vectorize_node_dot(op, node, batched_x, batched_y):
old_x, old_y = node.inputs
Expand Down Expand Up @@ -4218,6 +4366,9 @@
"max_and_argmax",
"max",
"matmul",
"vecdot",
"matvec",
"vecmat",
"argmax",
"min",
"argmin",
Expand Down
68 changes: 68 additions & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
logaddexp,
logsumexp,
matmul,
matvec,
max,
max_and_argmax,
maximum,
Expand Down Expand Up @@ -123,6 +124,8 @@
true_div,
trunc,
var,
vecdot,
vecmat,
)
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.type import (
Expand Down Expand Up @@ -2076,6 +2079,71 @@ def is_super_shape(var1, var2):
assert is_super_shape(y, g)


def test_matrix_vector_ops():
"""Test vecdot, matvec, and vecmat helper functions."""
rng = np.random.default_rng(seed=utt.fetch_seed())

# Create test data with batch dimension (2)
batch_size = 2
dim_k = 4 # Common dimension
dim_m = 3 # Matrix rows
dim_n = 5 # Matrix columns

# Create input tensors with appropriate shapes
# For matvec: x1(b,m,k) @ x2(b,k) -> out(b,m)
# For vecmat: x1(b,k) @ x2(b,k,n) -> out(b,n)

# Create test values using config.floatX to match PyTensor's default dtype
mat_mk_val = random(batch_size, dim_m, dim_k, rng=rng).astype(config.floatX)
mat_kn_val = random(batch_size, dim_k, dim_n, rng=rng).astype(config.floatX)
vec_k_val = random(batch_size, dim_k, rng=rng).astype(config.floatX)

# Create tensor variables with matching dtype
mat_mk = tensor(
name="mat_mk", shape=(batch_size, dim_m, dim_k), dtype=config.floatX
)
mat_kn = tensor(
name="mat_kn", shape=(batch_size, dim_k, dim_n), dtype=config.floatX
)
vec_k = tensor(name="vec_k", shape=(batch_size, dim_k), dtype=config.floatX)

# Test 1: vecdot with matching dimensions
vecdot_out = vecdot(vec_k, vec_k, dtype="int32")
vecdot_fn = function([vec_k], vecdot_out)
result = vecdot_fn(vec_k_val)

# Check dtype
assert result.dtype == np.int32

# Calculate expected manually
expected_vecdot = np.zeros((batch_size,), dtype=np.int32)
for i in range(batch_size):
expected_vecdot[i] = np.sum(vec_k_val[i] * vec_k_val[i])
np.testing.assert_allclose(result, expected_vecdot)

# Test 2: matvec - matrix-vector product
matvec_out = matvec(mat_mk, vec_k)
matvec_fn = function([mat_mk, vec_k], matvec_out)
result_matvec = matvec_fn(mat_mk_val, vec_k_val)

# Calculate expected manually
expected_matvec = np.zeros((batch_size, dim_m), dtype=config.floatX)
for i in range(batch_size):
expected_matvec[i] = np.dot(mat_mk_val[i], vec_k_val[i])
np.testing.assert_allclose(result_matvec, expected_matvec)

# Test 3: vecmat - vector-matrix product
vecmat_out = vecmat(vec_k, mat_kn)
vecmat_fn = function([vec_k, mat_kn], vecmat_out)
result_vecmat = vecmat_fn(vec_k_val, mat_kn_val)

# Calculate expected manually
expected_vecmat = np.zeros((batch_size, dim_n), dtype=config.floatX)
for i in range(batch_size):
expected_vecmat[i] = np.dot(vec_k_val[i], mat_kn_val[i])
np.testing.assert_allclose(result_vecmat, expected_vecmat)


class TestTensordot:
def TensorDot(self, axes):
# Since tensordot is no longer an op, mimic the old op signature
Expand Down