Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import TYPE_CHECKING, Union
from typing import cast as type_cast

import numba as nb
import numpy as np
from numpy.exceptions import AxisError

Expand Down Expand Up @@ -972,6 +973,7 @@ def make_node(self, a):
output = [TensorType(dtype="int64", shape=(None,))() for i in range(a.ndim)]
return Apply(self, [a], output)

@nb.njit
def perform(self, node, inp, out_):
a = inp[0]

Expand Down
Loading