-
Notifications
You must be signed in to change notification settings - Fork 143
Open
Labels
Description
Description
`ScalarOp`s are parametrized by an `output_types_preference` that determines the output types from the input types in `make_node` (unless an Op overrides `make_node`). This is used indirectly via `output_types`. The design is somewhat convoluted, but it is meant to make it easier to define `ScalarOp`s (and, IIRC, to tweak their behavior in the `InplaceOptimizer`).
The whole logic should be documented; in the process, we may find that it can be simplified. Also, every time we use it, we check whether the outputs of this function are valid, which is costly. Since nobody is really creating new `output_types_preference` functions, we should just trust them.
pytensor/pytensor/scalar/basic.py
Lines 1219 to 1272 in ee107cb
class ScalarOp(COp):
    """Base class for scalar operations.

    Subclasses declare their arity through ``nin`` / ``nout``. Unless a
    subclass overrides :meth:`make_node` or :meth:`output_types`, the output
    types of a node are derived from its input types by
    ``output_types_preference``.

    ``output_types_preference`` is a callable invoked as
    ``output_types_preference(*input_types)``. It must return a list or tuple
    of exactly ``nout`` ``CType`` instances. It can be passed to ``__init__``
    or defined as a class attribute on a subclass.
    """

    # Number of inputs; a negative value disables the input-count check.
    nin = -1
    # Number of outputs every node of this Op must produce.
    nout = 1

    def __init__(self, output_types_preference=None, name=None):
        """Initialize the Op.

        Parameters
        ----------
        output_types_preference
            Callable mapping input types to output types. When ``None``, an
            ``output_types_preference`` already defined on the class is kept;
            if there is none, the attribute is set to ``None``.
        name
            Optional name of the Op.

        Raises
        ------
        TypeError
            If ``output_types_preference`` is given but is not callable.
        """
        self.name = name
        if output_types_preference is not None:
            if not isinstance(output_types_preference, Callable):
                raise TypeError(
                    f"Expected a callable for the 'output_types_preference' argument to {self.__class__}. "
                    f"(got: {output_types_preference})"
                )
            self.output_types_preference = output_types_preference
        elif not hasattr(self, "output_types_preference"):
            # Only default to None when no subclass-level preference exists,
            # so class attributes are not overwritten.
            self.output_types_preference = None

    def make_node(self, *inputs):
        """Build an ``Apply`` node of this Op for ``inputs``.

        Every input is converted with ``as_scalar``. The output variables are
        instantiated from the types returned by :meth:`output_types`.

        Raises
        ------
        TypeError
            If ``nin`` is non-negative and the number of inputs differs from
            it, or if the number of produced outputs differs from ``nout``.
        """
        if self.nin >= 0 and len(inputs) != self.nin:
            raise TypeError(
                f"Wrong number of inputs for {self}.make_node "
                f"(got {len(inputs)}({inputs}), expected {self.nin})"
            )
        inputs = [as_scalar(inp) for inp in inputs]
        outputs = [
            out_type() for out_type in self.output_types([inp.type for inp in inputs])
        ]
        # A subclass may override ``output_types`` without validating its
        # result, so the output count is checked here as well.
        if len(outputs) != self.nout:
            # Plain string (the original accidentally built a 1-tuple here,
            # which leaked a tuple repr into the error message).
            inputs_str = ", ".join(str(inp) for inp in inputs)
            raise TypeError(
                f"Not the right number of outputs produced for {self}({inputs_str}). "
                f"Expected {self.nout}, got {len(outputs)}."
            )
        return Apply(self, inputs, outputs)

    def output_types(self, types):
        """Return the output types for a node whose input types are ``types``.

        Delegates to ``output_types_preference(*types)`` and validates the
        result.

        Parameters
        ----------
        types
            Sequence of input types.

        Returns
        -------
        list or tuple
            The ``nout`` ``CType`` instances returned by the preference.

        Raises
        ------
        TypeError
            If the preference does not return a list/tuple of ``CType``
            instances, or returns the wrong number of them.
        NotImplementedError
            If no ``output_types_preference`` is set.
        """
        preference = self.output_types_preference
        if preference is None:
            raise NotImplementedError(f"Cannot calculate the output types for {self}")

        out_types = preference(*types)
        if not isinstance(out_types, list | tuple) or not all(
            isinstance(out_type, CType) for out_type in out_types
        ):
            # Single formatted message: passing several positional args to
            # TypeError makes str(exc) a tuple repr.
            raise TypeError(
                "output_types_preference should return a list or a tuple of types "
                f"(preference: {preference}, got: {out_types})"
            )
        if len(out_types) != self.nout:
            types_str = ", ".join(str(out_type) for out_type in out_types)
            raise TypeError(
                "Not the right number of outputs types produced for "
                f"{self}({types_str}) by {preference}. "
                f"Expected {self.nout}, got {len(out_types)}."
            )
        return out_types