Skip to content

Commit 060d85f

Browse files
Fix gradient of minimize and root wrt higher-dimensional arguments (#1806)
* More precise imports * Fix gradient of `minimize` and `root` wrt higher-dimensional arguments Also: * Allow several inputs to be optimized * Handle disconnected and undefined gradients --------- Co-authored-by: ricardoV94 <[email protected]>
1 parent f360122 commit 060d85f

File tree

4 files changed

+562
-257
lines changed

4 files changed

+562
-257
lines changed

pytensor/sparse/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from pytensor.gradient import grad_not_implemented
1313
from pytensor.graph import Apply, Op
1414
from pytensor.link.c.op import COp
15-
from pytensor.tensor import TensorType, Variable, specify_broadcastable, tensor
16-
from pytensor.tensor.type import complex_dtypes
15+
from pytensor.tensor.shape import specify_broadcastable
16+
from pytensor.tensor.type import TensorType, Variable, complex_dtypes, tensor
1717

1818

1919
def structured_elemwise(tensor_op):

pytensor/sparse/variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
)
3131
from pytensor.sparse.type import SparseTensorType
3232
from pytensor.sparse.utils import hash_from_sparse
33-
from pytensor.tensor import iscalar
3433
from pytensor.tensor.shape import shape
34+
from pytensor.tensor.type import iscalar
3535
from pytensor.tensor.variable import (
3636
TensorConstant,
3737
TensorVariable,

0 commit comments

Comments
 (0)