Skip to content

Commit bec03eb

Browse files
Fix static shapes of outputs in TopKOp
1 parent eace7f6 commit bec03eb

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

aesara/tensor/sort.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,9 +414,13 @@ def make_node(self, inp, kth):
414414
_check_tensor_is_scalar(kth)
415415
outs = []
416416
if self.return_values:
417-
outs.append(inp.type())
417+
outs.append(
418+
TensorType(dtype=inp.type.dtype, shape=(None,) * inp.type.ndim)()
419+
)
418420
if self.return_indices:
419-
outs.append(TensorType(dtype=self.idx_dtype, shape=inp.type.shape)())
421+
outs.append(
422+
TensorType(dtype=self.idx_dtype, shape=(None,) * inp.type.ndim)()
423+
)
420424
return Apply(self, [inp, kth], outs)
421425

422426
def perform(self, node, inputs, output_storage):

0 commit comments

Comments
 (0)