@@ -207,19 +207,19 @@ def sp_zeros_like(x):
207207
208208
209209# for more dtypes, call SparseTensorType(format, dtype)
210- def matrix (format , name = None , dtype = None ):
210+ def matrix (format , name = None , dtype = None , shape = None ):
211211 if dtype is None :
212212 dtype = config .floatX
213- type = SparseTensorType (format = format , dtype = dtype )
213+ type = SparseTensorType (format = format , dtype = dtype , shape = shape )
214214 return type (name )
215215
216216
217- def csc_matrix (name = None , dtype = None ):
218- return matrix ("csc" , name , dtype )
217+ def csc_matrix (name = None , dtype = None , shape = None ):
218+ return matrix ("csc" , name = name , dtype = dtype , shape = shape )
219219
220220
221- def csr_matrix (name = None , dtype = None ):
222- return matrix ("csr" , name , dtype )
221+ def csr_matrix (name = None , dtype = None , shape = None ):
222+ return matrix ("csr" , name = name , dtype = dtype , shape = shape )
223223
224224
225225def bsr_matrix (name = None , dtype = None ):
@@ -434,10 +434,22 @@ def make_node(self, data, indices, indptr, shape):
434434 if shape .type .ndim != 1 or shape .type .dtype not in discrete_dtypes :
435435 raise TypeError ("n_rows must be integer type" , shape , shape .type )
436436
437+ static_shape = (None , None )
438+ if (
439+ shape .owner is not None
440+ and isinstance (shape .owner .op , CSMProperties )
441+ and shape .owner .outputs [3 ] is shape
442+ ):
443+ static_shape = shape .owner .inputs [0 ].type .shape
444+
437445 return Apply (
438446 self ,
439447 [data , indices , indptr , shape ],
440- [SparseTensorType (dtype = data .type .dtype , format = self .format )()],
448+ [
449+ SparseTensorType (
450+ dtype = data .type .dtype , format = self .format , shape = static_shape
451+ )()
452+ ],
441453 )
442454
443455 def perform (self , node , inputs , outputs ):
@@ -698,7 +710,7 @@ def make_node(self, x):
698710 return Apply (
699711 self ,
700712 [x ],
701- [TensorType (dtype = x .type .dtype , shape = ( None , None ) )()],
713+ [TensorType (dtype = x .type .dtype , shape = x . type . shape )()],
702714 )
703715
704716 def perform (self , node , inputs , outputs ):
0 commit comments