Skip to content

BUG: Passing check_parameters through minimize and tg.grad raises TypeError #1743

@Michal-Novomestsky

Description

@Michal-Novomestsky

Describe the issue:

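Passing the output of `check_parameters` (applied to scalar variables) through `pytensor.tensor.optimize.minimize`, and then calling `tg.grad` on the result, raises the error: TypeError: Input must be a ScalarType Type. The traceback below shows the failure occurs inside `ScalarFromTensor.grad`, which calls `tensor_from_scalar` on an output gradient whose type is not a `ScalarType`.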

Reproducible code example:

import pytensor.tensor as pt
import pytensor
import pytensor.gradient as tg
from pymc.distributions.dist_math import check_parameters

theta = pt.scalar("theta")
x = pt.scalar("x")

obj = check_parameters(x, theta > 0, msg="theta > 0")
                   
x0, _ = pytensor.tensor.optimize.minimize(objective=obj, x=x)

tg.grad(x0, theta)

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 13
      9 obj = check_parameters(x, theta > 0, msg="theta > 0")
     11 x0, _ = pytensor.tensor.optimize.minimize(objective=obj, x=x)
---> 13 tg.grad(x0, theta)

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:746, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    743     if hasattr(g.type, "dtype"):
    744         assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 746 _rval: Sequence[Variable] = _populate_grad_dict(
    747     var_to_app_to_idx, grad_dict, _wrt, cost_name
    748 )
    750 rval: MutableSequence[Variable | None] = list(_rval)
    752 for i in range(len(_rval)):

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1540, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1537     # end if cache miss
   1538     return grad_dict[var]
-> 1540 rval = [access_grad_cache(elem) for elem in wrt]
   1542 return rval

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1495, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1493 for node in node_to_idx:
   1494     for idx in node_to_idx[node]:
-> 1495         term = access_term_cache(node)[idx]
   1497         if not isinstance(term, Variable):
   1498             raise TypeError(
   1499                 f"{node.op}.grad returned {type(term)}, expected"
   1500                 " Variable instance."
   1501             )

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1170, in _populate_grad_dict.<locals>.access_term_cache(node)
   1167 if node not in term_dict:
   1168     inputs = node.inputs
-> 1170     output_grads = [access_grad_cache(var) for var in node.outputs]
   1172     # list of bools indicating if each output is connected to the cost
   1173     outputs_connected = [
   1174         not isinstance(g.type, DisconnectedType) for g in output_grads
   1175     ]

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1495, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1493 for node in node_to_idx:
   1494     for idx in node_to_idx[node]:
-> 1495         term = access_term_cache(node)[idx]
   1497         if not isinstance(term, Variable):
   1498             raise TypeError(
   1499                 f"{node.op}.grad returned {type(term)}, expected"
   1500                 " Variable instance."
   1501             )

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1170, in _populate_grad_dict.<locals>.access_term_cache(node)
   1167 if node not in term_dict:
   1168     inputs = node.inputs
-> 1170     output_grads = [access_grad_cache(var) for var in node.outputs]
   1172     # list of bools indicating if each output is connected to the cost
   1173     outputs_connected = [
   1174         not isinstance(g.type, DisconnectedType) for g in output_grads
   1175     ]

    [... skipping similar frames: _populate_grad_dict.<locals>.access_grad_cache at line 1495 (2 times), _populate_grad_dict.<locals>.access_term_cache at line 1170 (1 times)]

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1170, in _populate_grad_dict.<locals>.access_term_cache(node)
   1167 if node not in term_dict:
   1168     inputs = node.inputs
-> 1170     output_grads = [access_grad_cache(var) for var in node.outputs]
   1172     # list of bools indicating if each output is connected to the cost
   1173     outputs_connected = [
   1174         not isinstance(g.type, DisconnectedType) for g in output_grads
   1175     ]

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1495, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1493 for node in node_to_idx:
   1494     for idx in node_to_idx[node]:
-> 1495         term = access_term_cache(node)[idx]
   1497         if not isinstance(term, Variable):
   1498             raise TypeError(
   1499                 f"{node.op}.grad returned {type(term)}, expected"
   1500                 " Variable instance."
   1501             )

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/gradient.py:1325, in _populate_grad_dict.<locals>.access_term_cache(node)
   1317         if o_shape != g_shape:
   1318             raise ValueError(
   1319                 "Got a gradient of shape "
   1320                 + str(o_shape)
   1321                 + " on an output of shape "
   1322                 + str(g_shape)
   1323             )
-> 1325 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1327 if input_grads is None:
   1328     raise TypeError(
   1329         f"{node.op}.grad returned NoneType, expected iterable."
   1330     )

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/graph/op.py:399, in Op.L_op(self, inputs, outputs, output_grads)
    372 def L_op(
    373     self,
    374     inputs: Sequence[Variable],
    375     outputs: Sequence[Variable],
    376     output_grads: Sequence[Variable],
    377 ) -> list[Variable]:
    378     r"""Construct a graph for the L-operator.
    379 
    380     The L-operator computes a row vector times the Jacobian.
   (...)    397 
    398     """
--> 399     return self.grad(inputs, output_grads)

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:696, in ScalarFromTensor.grad(self, inp, grads)
    694 (_s,) = inp
    695 (dt,) = grads
--> 696 return [tensor_from_scalar(dt)]

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/graph/op.py:294, in Op.__call__(self, name, return_list, *inputs, **kwargs)
    250 def __call__(
    251     self, *inputs: Any, name=None, return_list=False, **kwargs
    252 ) -> Variable | list[Variable]:
    253     r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    254 
    255     This method is just a wrapper around :meth:`Op.make_node`.
   (...)    292 
    293     """
--> 294     node = self.make_node(*inputs, **kwargs)
    295     if name is not None:
    296         if len(node.outputs) == 1:

File ~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/basic.py:620, in TensorFromScalar.make_node(self, s)
    618 def make_node(self, s):
    619     if not isinstance(s.type, ps.ScalarType):
--> 620         raise TypeError("Input must be a `ScalarType` `Type`")
    622     return Apply(self, [s], [tensor(dtype=s.type.dtype, shape=())])

TypeError: Input must be a `ScalarType` `Type`

PyTensor version information:

pytensor version 2.35.1

Context for the issue:

No response

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions