Skip to content

Commit 84f9e7b

Browse files
committed
Handle static shape in core sparse methods
1 parent ec07492 commit 84f9e7b

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

pytensor/sparse/basic.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -207,19 +207,19 @@ def sp_zeros_like(x):
207207

208208

209209
# for more dtypes, call SparseTensorType(format, dtype)
210-
def matrix(format, name=None, dtype=None):
210+
def matrix(format, name=None, dtype=None, shape=None):
211211
if dtype is None:
212212
dtype = config.floatX
213-
type = SparseTensorType(format=format, dtype=dtype)
213+
type = SparseTensorType(format=format, dtype=dtype, shape=shape)
214214
return type(name)
215215

216216

217-
def csc_matrix(name=None, dtype=None):
218-
return matrix("csc", name, dtype)
217+
def csc_matrix(name=None, dtype=None, shape=None):
218+
return matrix("csc", name=name, dtype=dtype, shape=shape)
219219

220220

221-
def csr_matrix(name=None, dtype=None):
222-
return matrix("csr", name, dtype)
221+
def csr_matrix(name=None, dtype=None, shape=None):
222+
return matrix("csr", name=name, dtype=dtype, shape=shape)
223223

224224

225225
def bsr_matrix(name=None, dtype=None):
@@ -434,10 +434,22 @@ def make_node(self, data, indices, indptr, shape):
434434
if shape.type.ndim != 1 or shape.type.dtype not in discrete_dtypes:
435435
raise TypeError("n_rows must be integer type", shape, shape.type)
436436

437+
static_shape = (None, None)
438+
if (
439+
shape.owner is not None
440+
and isinstance(shape.owner.op, CSMProperties)
441+
and shape.owner.outputs[3] is shape
442+
):
443+
static_shape = shape.owner.inputs[0].type.shape
444+
437445
return Apply(
438446
self,
439447
[data, indices, indptr, shape],
440-
[SparseTensorType(dtype=data.type.dtype, format=self.format)()],
448+
[
449+
SparseTensorType(
450+
dtype=data.type.dtype, format=self.format, shape=static_shape
451+
)()
452+
],
441453
)
442454

443455
def perform(self, node, inputs, outputs):
@@ -698,7 +710,7 @@ def make_node(self, x):
698710
return Apply(
699711
self,
700712
[x],
701-
[TensorType(dtype=x.type.dtype, shape=(None, None))()],
713+
[TensorType(dtype=x.type.dtype, shape=x.type.shape)()],
702714
)
703715

704716
def perform(self, node, inputs, outputs):

pytensor/sparse/variable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ def sum(self, axis=None, sparse_grad=False):
127127
def toarray(self):
128128
return dense_from_sparse(self)
129129

130+
todense = toarray
131+
130132
@property
131133
def shape(self):
132134
# TODO: The plan is that the ShapeFeature in ptb.opt will do shape

0 commit comments

Comments
 (0)