 from pytensor.gradient import grad_not_implemented
 from pytensor.graph import Apply, Op
 from pytensor.link.c.op import COp
+from pytensor.sparse import SparseTensorType
 from pytensor.tensor import TensorType, Variable, specify_broadcastable, tensor
 from pytensor.tensor.type import complex_dtypes

@@ -428,7 +429,7 @@ def make_node(self, x, y):
         return Apply(
             self,
             [x, y],
-            [psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()],
+            [SparseTensorType(dtype=out_dtype, format=x.type.format)()],
         )

     def perform(self, node, inputs, outputs):
@@ -488,7 +489,7 @@ def make_node(self, x, y):
         return Apply(
             self,
             [x, y],
-            [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
+            [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
         )

     def perform(self, node, inputs, outputs):
@@ -591,7 +592,7 @@ def make_node(self, x, y):
         return Apply(
             self,
             [x, y],
-            [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
+            [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
        )

     def perform(self, node, inputs, outputs):
@@ -707,7 +708,7 @@ def sub(x, y):
 sub.__doc__ = subtract.__doc__


-class MulSS(Op):
+class SparseSparseMultiply(Op):
     # mul(sparse, sparse)
     # See the doc of mul() for more detail
     __props__ = ()
@@ -720,7 +721,7 @@ def make_node(self, x, y):
         return Apply(
             self,
             [x, y],
-            [psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()],
+            [SparseTensorType(dtype=out_dtype, format=x.type.format)()],
         )

     def perform(self, node, inputs, outputs):
@@ -742,10 +743,10 @@ def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]


-mul_s_s = MulSS()
+mul_s_s = SparseSparseMultiply()


-class MulSD(Op):
+class SparseDenseMultiply(Op):
     # mul(sparse, dense)
     # See the doc of mul() for more detail
     __props__ = ()
@@ -762,65 +763,63 @@ def make_node(self, x, y):
         # objects must be matrices (have dimension 2)
         # Broadcasting of the sparse matrix is not supported.
         # We support nd == 0 used by grad of SpSum()
-        assert y.type.ndim in (0, 2)
-        out = psb.SparseTensorType(dtype=dtype, format=x.type.format)()
+        if y.type.ndim not in (0, 2):
+            raise ValueError(f"y {y} must have 0 or 2 dimensions. Got {y.type.ndim}")
+        if y.type.ndim == 0:
+            out_shape = x.type.shape
+        if y.type.ndim == 2:
+            # Combine with static shape information from y
+            out_shape = []
+            for x_st_dim_length, y_st_dim_length in zip(x.type.shape, y.type.shape):
+                if x_st_dim_length is None:
+                    out_shape.append(y_st_dim_length)
+                else:
+                    out_shape.append(x_st_dim_length)
+                    # If both are known, they must match
+                    if (
+                        y_st_dim_length is not None
+                        and y_st_dim_length != x_st_dim_length
+                    ):
+                        raise ValueError(
+                            f"Incompatible static shapes {x}: {x.type.shape}, {y}: {y.type.shape}"
+                        )
+            out_shape = tuple(out_shape)
+        out = SparseTensorType(dtype=dtype, format=x.type.format, shape=out_shape)()
         return Apply(self, [x, y], [out])

     def perform(self, node, inputs, outputs):
         (x, y) = inputs
         (out,) = outputs
+        out_dtype = node.outputs[0].dtype
         assert psb._is_sparse(x) and psb._is_dense(y)
-        if len(y.shape) == 0:
-            out_dtype = node.outputs[0].dtype
-            if x.dtype == out_dtype:
-                z = x.copy()
-            else:
-                z = x.astype(out_dtype)
-            out[0] = z
-            out[0].data *= y
-        elif len(y.shape) == 1:
-            raise NotImplementedError()  # RowScale / ColScale
-        elif len(y.shape) == 2:
+
+        if x.dtype == out_dtype:
+            z = x.copy()
+        else:
+            z = x.astype(out_dtype)
+        out[0] = z
+        z_data = z.data
+
+        if y.ndim == 0:
+            z_data *= y
+        else:  # y_ndim == 2
             # if we have enough memory to fit y, maybe we can fit x.asarray()
             # too?
             # TODO: change runtime from O(M*N) to O(nonzeros)
             M, N = x.shape
             assert x.shape == y.shape
-            out_dtype = node.outputs[0].dtype
-
+            indices = x.indices
+            indptr = x.indptr
             if x.format == "csc":
-                indices = x.indices
-                indptr = x.indptr
-                if x.dtype == out_dtype:
-                    z = x.copy()
-                else:
-                    z = x.astype(out_dtype)
-                z_data = z.data
-
                 for j in range(0, N):
                     for i_idx in range(indptr[j], indptr[j + 1]):
                         i = indices[i_idx]
                         z_data[i_idx] *= y[i, j]
-                out[0] = z
             elif x.format == "csr":
-                indices = x.indices
-                indptr = x.indptr
-                if x.dtype == out_dtype:
-                    z = x.copy()
-                else:
-                    z = x.astype(out_dtype)
-                z_data = z.data
-
                 for i in range(0, M):
                     for j_idx in range(indptr[i], indptr[i + 1]):
                         j = indices[j_idx]
                         z_data[j_idx] *= y[i, j]
-                out[0] = z
-            else:
-                warn(
-                    "This implementation of MulSD is deficient: {x.format}",
-                )
-                out[0] = type(x)(x.toarray() * y)

     def grad(self, inputs, gout):
         (x, y) = inputs
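As a quick illustration of the static-shape propagation added to SparseDenseMultiply.make_node above, here is a rough sketch (not part of the diff); it assumes SparseTensorType and mul_s_d are importable from pytensor.sparse and that the dense tensor helpers accept a shape argument:

    from pytensor.sparse import SparseTensorType, mul_s_d
    from pytensor.tensor import matrix

    # Sparse input with an unknown row count but 5 columns,
    # dense input with 3 rows and an unknown column count.
    x = SparseTensorType(dtype="float64", format="csr", shape=(None, 5))("x")
    y = matrix("y", dtype="float64", shape=(3, None))

    z = mul_s_d(x, y)
    print(z.type.shape)  # (3, 5): each None is filled in from the other operand
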
@@ -833,10 +832,10 @@ def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]


-mul_s_d = MulSD()
+mul_s_d = SparseDenseMultiply()


-class MulSV(Op):
+class SparseDenseVectorMultiply(Op):
     """Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise.

     Notes
@@ -845,6 +844,8 @@ class MulSV(Op):

     """

+    # TODO: Merge with the SparseDenseMultiply Op
+
     __props__ = ()

     def make_node(self, x, y):
@@ -861,17 +862,30 @@ def make_node(self, x, y):
         assert x.format in ("csr", "csc")
         y = ptb.as_tensor_variable(y)

-        assert y.type.ndim == 1
+        if y.type.ndim != 1:
+            raise ValueError(f"y {y} must have 1 dimension. Got {y.type.ndim}")

         if x.type.dtype != y.type.dtype:
             raise NotImplementedError(
-                "MulSV not implemented for differing dtypes."
-                f"Got {x.type.dtype} and {y.type.dtype}."
+                f"Differing dtypes not supported. Got {x.type.dtype} and {y.type.dtype}."
             )
+        out_shape = [x.type.shape[0]]
+        if x.type.shape[-1] is None:
+            out_shape.append(y.type.shape[0])
+        else:
+            out_shape.append(x.type.shape[-1])
+            if y.type.shape[-1] is not None and x.type.shape[-1] != y.type.shape[-1]:
+                raise ValueError(
+                    f"Incompatible static shapes for multiplication {x}: {x.type.shape}, {y}: {y.type.shape}"
+                )
         return Apply(
             self,
             [x, y],
-            [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
+            [
+                SparseTensorType(
+                    dtype=x.type.dtype, format=x.type.format, shape=tuple(out_shape)
+                )()
+            ],
         )

     def perform(self, node, inputs, outputs):
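A similar sketch (same assumptions as above) for the reworked SparseDenseVectorMultiply.make_node, which now fills in an unknown column count from the vector operand:

    from pytensor.sparse import SparseTensorType, mul_s_v
    from pytensor.tensor import vector

    s = SparseTensorType(dtype="float64", format="csr", shape=(4, None))("s")
    v = vector("v", dtype="float64", shape=(3,))

    print(mul_s_v(s, v).type.shape)  # (4, 3): the column count comes from v
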
@@ -901,7 +915,7 @@ def infer_shape(self, fgraph, node, ins_shapes):
         return [ins_shapes[0]]


-mul_s_v = MulSV()
+mul_s_v = SparseDenseVectorMultiply()


 def multiply(x, y):
@@ -940,16 +954,17 @@ def multiply(x, y):
         # mul_s_s is not implemented if the types differ
         if y.dtype == "float64" and x.dtype == "float32":
             x = x.astype("float64")
-
         return mul_s_s(x, y)
-    elif x_is_sparse_variable and not y_is_sparse_variable:
+    elif x_is_sparse_variable or y_is_sparse_variable:
+        if y_is_sparse_variable:
+            x, y = y, x
         # mul is unimplemented if the dtypes differ
         if y.dtype == "float64" and x.dtype == "float32":
             x = x.astype("float64")
-
-        return mul_s_d(x, y)
-    elif y_is_sparse_variable and not x_is_sparse_variable:
-        return mul_s_d(y, x)
+        if y.ndim == 1:
+            return mul_s_v(x, y)
+        else:
+            return mul_s_d(x, y)
     else:
         raise NotImplementedError()

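For reference, a small usage sketch (not part of the diff) of the reworked dispatch in multiply(): a sparse/1-d dense pair should now be routed to mul_s_v, with the operands swapped internally when the sparse input comes second. It assumes multiply and SparseDenseVectorMultiply are importable from the module being edited:

    from pytensor.sparse import SparseTensorType
    from pytensor.tensor import vector

    s = SparseTensorType(dtype="float64", format="csr", shape=(None, None))("s")
    v = vector("v", dtype="float64")

    assert isinstance(multiply(s, v).owner.op, SparseDenseVectorMultiply)
    assert isinstance(multiply(v, s).owner.op, SparseDenseVectorMultiply)
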
@@ -999,7 +1014,7 @@ def make_node(self, x, y):
         if x.type.format != y.type.format:
             raise NotImplementedError()
         return Apply(
-            self, [x, y], [psb.SparseTensorType(dtype="uint8", format=x.type.format)()]
+            self, [x, y], [SparseTensorType(dtype="uint8", format=x.type.format)()]
         )

     def perform(self, node, inputs, outputs):
@@ -1252,7 +1267,7 @@ def make_node(self, x, y):
             raise NotImplementedError()

         inputs = [x, y]  # Need to convert? e.g. assparse
-        outputs = [psb.SparseTensorType(dtype=x.type.dtype, format=myformat)()]
+        outputs = [SparseTensorType(dtype=x.type.dtype, format=myformat)()]
         return Apply(self, inputs, outputs)

     def perform(self, node, inp, out_):
@@ -1373,9 +1388,7 @@ def make_node(self, a, b):
             raise NotImplementedError("non-matrix b")

         if psb._is_sparse_variable(b):
-            return Apply(
-                self, [a, b], [psb.SparseTensorType(a.type.format, dtype_out)()]
-            )
+            return Apply(self, [a, b], [SparseTensorType(a.type.format, dtype_out)()])
         else:
             return Apply(
                 self,
@@ -1397,7 +1410,7 @@ def perform(self, node, inputs, outputs):
             )

         variable = a * b
-        if isinstance(node.outputs[0].type, psb.SparseTensorType):
+        if isinstance(node.outputs[0].type, SparseTensorType):
             assert psb._is_sparse(variable)
             out[0] = variable
             return