Commit 29697f5

Cleanup SparseMultiply Ops
* Handle static shape
* Rename to more readable Op classes
* Simplify perform
1 parent 648e6e2 commit 29697f5
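
For orientation, the three Ops renamed in this commit all implement element-wise multiplication with a sparse operand. A minimal sketch of the semantics only, written against scipy.sparse rather than the PyTensor API (all names below are made up for the example):

import numpy as np
import scipy.sparse as sp

a = sp.random(4, 3, density=0.5, format="csr", random_state=0)  # sparse operand
b = sp.random(4, 3, density=0.5, format="csr", random_state=1)  # second sparse operand
d = np.ones((4, 3))                                             # dense matrix
v = np.arange(3, dtype="float64")                               # dense vector, length == number of columns

a.multiply(b).toarray()  # what mul_s_s (SparseSparseMultiply) computes
a.toarray() * d          # what mul_s_d (SparseDenseMultiply) computes, shown densely
a.toarray() * v          # what mul_s_v (SparseDenseVectorMultiply) computes: v is broadcast along the rows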

2 files changed: +81 -66 lines

pytensor/sparse/math.py

Lines changed: 77 additions & 64 deletions
@@ -12,6 +12,7 @@
 from pytensor.gradient import grad_not_implemented
 from pytensor.graph import Apply, Op
 from pytensor.link.c.op import COp
+from pytensor.sparse import SparseTensorType
 from pytensor.tensor import TensorType, Variable, specify_broadcastable, tensor
 from pytensor.tensor.type import complex_dtypes

@@ -428,7 +429,7 @@ def make_node(self, x, y):
         return Apply(
             self,
             [x, y],
-            [psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()],
+            [SparseTensorType(dtype=out_dtype, format=x.type.format)()],
         )

     def perform(self, node, inputs, outputs):
@@ -488,7 +489,7 @@ def make_node(self, x, y):
         return Apply(
             self,
             [x, y],
-            [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
+            [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
         )

     def perform(self, node, inputs, outputs):
@@ -591,7 +592,7 @@ def make_node(self, x, y):
         return Apply(
             self,
             [x, y],
-            [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
+            [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
         )

     def perform(self, node, inputs, outputs):
@@ -707,7 +708,7 @@ def sub(x, y):
 sub.__doc__ = subtract.__doc__


-class MulSS(Op):
+class SparseSparseMultiply(Op):
     # mul(sparse, sparse)
     # See the doc of mul() for more detail
     __props__ = ()
@@ -720,7 +721,7 @@ def make_node(self, x, y):
         return Apply(
             self,
             [x, y],
-            [psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()],
+            [SparseTensorType(dtype=out_dtype, format=x.type.format)()],
         )

     def perform(self, node, inputs, outputs):
@@ -742,10 +743,10 @@ def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]


-mul_s_s = MulSS()
+mul_s_s = SparseSparseMultiply()


-class MulSD(Op):
+class SparseDenseMultiply(Op):
     # mul(sparse, dense)
     # See the doc of mul() for more detail
     __props__ = ()
@@ -762,65 +763,63 @@ def make_node(self, x, y):
         # objects must be matrices (have dimension 2)
         # Broadcasting of the sparse matrix is not supported.
         # We support nd == 0 used by grad of SpSum()
-        assert y.type.ndim in (0, 2)
-        out = psb.SparseTensorType(dtype=dtype, format=x.type.format)()
+        if y.type.ndim not in (0, 2):
+            raise ValueError(f"y {y} must have 0 or 2 dimensions. Got {y.type.ndim}")
+        if y.type.ndim == 0:
+            out_shape = x.type.shape
+        if y.type.ndim == 2:
+            # Combine with static shape information from y
+            out_shape = []
+            for x_st_dim_length, y_st_dim_length in zip(x.type.shape, y.type.shape):
+                if x_st_dim_length is None:
+                    out_shape.append(y_st_dim_length)
+                else:
+                    out_shape.append(x_st_dim_length)
+                    # If both are known, they must match
+                    if (
+                        y_st_dim_length is not None
+                        and y_st_dim_length != x_st_dim_length
+                    ):
+                        raise ValueError(
+                            f"Incompatible static shapes {x}: {x.type.shape}, {y}: {y.type.shape}"
+                        )
+            out_shape = tuple(out_shape)
+        out = SparseTensorType(dtype=dtype, format=x.type.format, shape=out_shape)()
         return Apply(self, [x, y], [out])

     def perform(self, node, inputs, outputs):
         (x, y) = inputs
         (out,) = outputs
+        out_dtype = node.outputs[0].dtype
         assert psb._is_sparse(x) and psb._is_dense(y)
-        if len(y.shape) == 0:
-            out_dtype = node.outputs[0].dtype
-            if x.dtype == out_dtype:
-                z = x.copy()
-            else:
-                z = x.astype(out_dtype)
-            out[0] = z
-            out[0].data *= y
-        elif len(y.shape) == 1:
-            raise NotImplementedError()  # RowScale / ColScale
-        elif len(y.shape) == 2:
+
+        if x.dtype == out_dtype:
+            z = x.copy()
+        else:
+            z = x.astype(out_dtype)
+        out[0] = z
+        z_data = z.data
+
+        if y.ndim == 0:
+            z_data *= y
+        else:  # y_ndim == 2
             # if we have enough memory to fit y, maybe we can fit x.asarray()
             # too?
             # TODO: change runtime from O(M*N) to O(nonzeros)
             M, N = x.shape
             assert x.shape == y.shape
-            out_dtype = node.outputs[0].dtype
-
+            indices = x.indices
+            indptr = x.indptr
             if x.format == "csc":
-                indices = x.indices
-                indptr = x.indptr
-                if x.dtype == out_dtype:
-                    z = x.copy()
-                else:
-                    z = x.astype(out_dtype)
-                z_data = z.data
-
                 for j in range(0, N):
                     for i_idx in range(indptr[j], indptr[j + 1]):
                         i = indices[i_idx]
                         z_data[i_idx] *= y[i, j]
-                out[0] = z
             elif x.format == "csr":
-                indices = x.indices
-                indptr = x.indptr
-                if x.dtype == out_dtype:
-                    z = x.copy()
-                else:
-                    z = x.astype(out_dtype)
-                z_data = z.data
-
                 for i in range(0, M):
                     for j_idx in range(indptr[i], indptr[i + 1]):
                         j = indices[j_idx]
                         z_data[j_idx] *= y[i, j]
-                out[0] = z
-            else:
-                warn(
-                    "This implementation of MulSD is deficient: {x.format}",
-                )
-                out[0] = type(x)(x.toarray() * y)

     def grad(self, inputs, gout):
         (x, y) = inputs
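
The simplified perform above keeps the original CSC/CSR index traversal. A standalone sketch of that loop using scipy directly (the names x, y, z, indptr, indices mirror the diff; everything else is illustrative):

import numpy as np
import scipy.sparse as sp

x = sp.random(5, 4, density=0.4, format="csc", random_state=0)
y = np.arange(20, dtype="float64").reshape(5, 4)

z = x.copy()
z_data = z.data
M, N = x.shape
indices, indptr = x.indices, x.indptr

# CSC layout: column j owns the data slice indptr[j]:indptr[j + 1];
# indices holds the row index of each stored value.
for j in range(N):
    for i_idx in range(indptr[j], indptr[j + 1]):
        i = indices[i_idx]
        z_data[i_idx] *= y[i, j]

# Only stored entries are touched, so the result matches the dense product.
assert np.allclose(z.toarray(), x.toarray() * y)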
@@ -833,10 +832,10 @@ def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]


-mul_s_d = MulSD()
+mul_s_d = SparseDenseMultiply()


-class MulSV(Op):
+class SparseDenseVectorMultiply(Op):
     """Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise.

     Notes
@@ -845,6 +844,8 @@ class MulSV(Op):

     """

+    # TODO: Merge with the SparseDenseMultiply Op
+
     __props__ = ()

     def make_node(self, x, y):
@@ -861,17 +862,30 @@ def make_node(self, x, y):
         assert x.format in ("csr", "csc")
         y = ptb.as_tensor_variable(y)

-        assert y.type.ndim == 1
+        if y.type.ndim != 1:
+            raise ValueError(f"y {y} must have 1 dimension. Got {y.type.ndim}")

         if x.type.dtype != y.type.dtype:
             raise NotImplementedError(
-                "MulSV not implemented for differing dtypes."
-                f"Got {x.type.dtype} and {y.type.dtype}."
+                f"Differing dtypes not supported. Got {x.type.dtype} and {y.type.dtype}."
             )
+        out_shape = [x.type.shape[0]]
+        if x.type.shape[-1] is None:
+            out_shape.append(y.type.shape[0])
+        else:
+            out_shape.append(x.type.shape[-1])
+            if y.type.shape[-1] is not None and x.type.shape[-1] != y.type.shape[-1]:
+                raise ValueError(
+                    f"Incompatible static shapes for multiplication {x}: {x.type.shape}, {y}: {y.type.shape}"
+                )
         return Apply(
             self,
             [x, y],
-            [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
+            [
+                SparseTensorType(
+                    dtype=x.type.dtype, format=x.type.format, shape=tuple(out_shape)
+                )()
+            ],
         )

     def perform(self, node, inputs, outputs):
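
Both make_node methods above now fold the operands' static shapes into the output SparseTensorType. The merge rule they implement, pulled out into a tiny standalone helper for illustration (this helper is not part of the diff):

def combine_static_shapes(x_shape, y_shape):
    """Merge static shapes elementwise; None means unknown until runtime."""
    out_shape = []
    for x_dim, y_dim in zip(x_shape, y_shape):
        if x_dim is None:
            out_shape.append(y_dim)
        else:
            if y_dim is not None and y_dim != x_dim:
                raise ValueError(f"Incompatible static shapes {x_shape} and {y_shape}")
            out_shape.append(x_dim)
    return tuple(out_shape)

assert combine_static_shapes((None, 3), (2, None)) == (2, 3)
assert combine_static_shapes((None, None), (2, 3)) == (2, 3)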
@@ -901,7 +915,7 @@ def infer_shape(self, fgraph, node, ins_shapes):
         return [ins_shapes[0]]


-mul_s_v = MulSV()
+mul_s_v = SparseDenseVectorMultiply()


 def multiply(x, y):
@@ -940,16 +954,17 @@ def multiply(x, y):
         # mul_s_s is not implemented if the types differ
         if y.dtype == "float64" and x.dtype == "float32":
             x = x.astype("float64")
-
         return mul_s_s(x, y)
-    elif x_is_sparse_variable and not y_is_sparse_variable:
+    elif x_is_sparse_variable or y_is_sparse_variable:
+        if y_is_sparse_variable:
+            x, y = y, x
         # mul is unimplemented if the dtypes differ
         if y.dtype == "float64" and x.dtype == "float32":
             x = x.astype("float64")
-
-        return mul_s_d(x, y)
-    elif y_is_sparse_variable and not x_is_sparse_variable:
-        return mul_s_d(y, x)
+        if y.ndim == 1:
+            return mul_s_v(x, y)
+        else:
+            return mul_s_d(x, y)
     else:
         raise NotImplementedError()
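
The rewritten multiply above also changes the routing for a dense 1-d operand: it now goes to mul_s_v, whereas it previously went to mul_s_d, whose make_node only accepts 0-d or 2-d dense inputs. A compact paraphrase of the new branch structure (pure Python, illustrative names only):

def pick_op(x_is_sparse, y_is_sparse, dense_ndim):
    if x_is_sparse and y_is_sparse:
        return "mul_s_s"
    elif x_is_sparse or y_is_sparse:
        # the sparse operand is swapped into the first position
        return "mul_s_v" if dense_ndim == 1 else "mul_s_d"
    else:
        raise NotImplementedError()

assert pick_op(True, False, 1) == "mul_s_v"
assert pick_op(False, True, 2) == "mul_s_d"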
@@ -999,7 +1014,7 @@ def make_node(self, x, y):
         if x.type.format != y.type.format:
             raise NotImplementedError()
         return Apply(
-            self, [x, y], [psb.SparseTensorType(dtype="uint8", format=x.type.format)()]
+            self, [x, y], [SparseTensorType(dtype="uint8", format=x.type.format)()]
         )

     def perform(self, node, inputs, outputs):
@@ -1252,7 +1267,7 @@ def make_node(self, x, y):
             raise NotImplementedError()

         inputs = [x, y]  # Need to convert? e.g. assparse
-        outputs = [psb.SparseTensorType(dtype=x.type.dtype, format=myformat)()]
+        outputs = [SparseTensorType(dtype=x.type.dtype, format=myformat)()]
         return Apply(self, inputs, outputs)

     def perform(self, node, inp, out_):
@@ -1373,9 +1388,7 @@ def make_node(self, a, b):
             raise NotImplementedError("non-matrix b")

         if psb._is_sparse_variable(b):
-            return Apply(
-                self, [a, b], [psb.SparseTensorType(a.type.format, dtype_out)()]
-            )
+            return Apply(self, [a, b], [SparseTensorType(a.type.format, dtype_out)()])
         else:
             return Apply(
                 self,
@@ -1397,7 +1410,7 @@ def perform(self, node, inputs, outputs):
             )

         variable = a * b
-        if isinstance(node.outputs[0].type, psb.SparseTensorType):
+        if isinstance(node.outputs[0].type, SparseTensorType):
             assert psb._is_sparse(variable)
             out[0] = variable
             return

tests/sparse/test_rewriting.py

Lines changed: 4 additions & 2 deletions
@@ -78,7 +78,8 @@ def test_local_mul_s_d():
     f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode=mode)

     assert not any(
-        isinstance(node.op, smath.MulSD) for node in f.maker.fgraph.toposort()
+        isinstance(node.op, smath.SparseDenseMultiply)
+        for node in f.maker.fgraph.toposort()
     )


@@ -95,7 +96,8 @@ def test_local_mul_s_v():
     f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode=mode)

     assert not any(
-        isinstance(node.op, smath.MulSV) for node in f.maker.fgraph.toposort()
+        isinstance(node.op, smath.SparseDenseVectorMultiply)
+        for node in f.maker.fgraph.toposort()
     )