Skip to content

Commit 648e6e2

Browse files
committed
Handle static shape in CSM and sparse matrix constructors
1 parent 80afd07 commit 648e6e2

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

pytensor/sparse/basic.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -207,19 +207,19 @@ def sp_zeros_like(x):
207207

208208

209209
# for more dtypes, call SparseTensorType(format, dtype)
210-
def matrix(format, name=None, dtype=None):
210+
def matrix(format, name=None, dtype=None, shape=None):
211211
if dtype is None:
212212
dtype = config.floatX
213-
type = SparseTensorType(format=format, dtype=dtype)
213+
type = SparseTensorType(format=format, dtype=dtype, shape=shape)
214214
return type(name)
215215

216216

217-
def csc_matrix(name=None, dtype=None):
218-
return matrix("csc", name, dtype)
217+
def csc_matrix(name=None, dtype=None, shape=None):
218+
return matrix("csc", name=name, dtype=dtype, shape=shape)
219219

220220

221-
def csr_matrix(name=None, dtype=None):
222-
return matrix("csr", name, dtype)
221+
def csr_matrix(name=None, dtype=None, shape=None):
222+
return matrix("csr", name=name, dtype=dtype, shape=shape)
223223

224224

225225
def bsr_matrix(name=None, dtype=None):
@@ -434,10 +434,22 @@ def make_node(self, data, indices, indptr, shape):
434434
if shape.type.ndim != 1 or shape.type.dtype not in discrete_dtypes:
435435
raise TypeError("n_rows must be integer type", shape, shape.type)
436436

437+
static_shape = (None, None)
438+
if (
439+
shape.owner is not None
440+
and isinstance(shape.owner.op, CSMProperties)
441+
and shape.owner.outputs[3] is shape
442+
):
443+
static_shape = shape.owner.inputs[0].type.shape
444+
437445
return Apply(
438446
self,
439447
[data, indices, indptr, shape],
440-
[SparseTensorType(dtype=data.type.dtype, format=self.format)()],
448+
[
449+
SparseTensorType(
450+
dtype=data.type.dtype, format=self.format, shape=static_shape
451+
)()
452+
],
441453
)
442454

443455
def perform(self, node, inputs, outputs):

0 commit comments

Comments
 (0)