
Document output_types_preference in ScalarOp #1617

@ricardoV94

Description

ScalarOps are parametrized by an `output_types_preference` callable that determines the output types from the input types in `make_node` (unless an Op overrides `make_node`). It is used indirectly, via `output_types`. The design is somewhat convoluted, but it is meant to make it easier to define new ScalarOps (and, IIRC, to tweak their behavior in the InplaceOptimizer).
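A minimal, self-contained sketch of this pattern, using hypothetical toy stand-ins (`ToyType`, `ToyVariable`, `ToyApply`, `ToyScalarOp`, `same_out` and `upcast_out` are illustrative only, not PyTensor's real classes or helpers). The op stores a callable, `make_node` asks `output_types` for the output types, and `output_types` simply delegates to that callable. The dtype ordering used for "upcasting" is a deliberate simplification, not NumPy's promotion rules.

```python
from collections.abc import Callable
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyType:
    """Hypothetical scalar type; only the dtype matters here."""
    dtype: str

    def __call__(self):
        # Calling a type creates a fresh variable of that type.
        return ToyVariable(self)


@dataclass
class ToyVariable:
    type: ToyType


@dataclass
class ToyApply:
    op: object
    inputs: list
    outputs: list


# Simplified dtype ordering, narrowest to widest (assumption for the toy).
_ORDER = ["int8", "int32", "int64", "float32", "float64"]


def same_out(*types):
    # Hypothetical preference: the output has the type of the first input.
    return (types[0],)


def upcast_out(*types):
    # Hypothetical preference: the output uses the widest input dtype.
    return (ToyType(max((t.dtype for t in types), key=_ORDER.index)),)


class ToyScalarOp:
    nout = 1

    def __init__(self, output_types_preference: Callable):
        self.output_types_preference = output_types_preference

    def output_types(self, types):
        # All type logic lives in the callable the op was built with.
        return self.output_types_preference(*types)

    def make_node(self, *inputs):
        outputs = [t() for t in self.output_types([i.type for i in inputs])]
        return ToyApply(self, list(inputs), outputs)


add = ToyScalarOp(upcast_out)
neg = ToyScalarOp(same_out)
x, y = ToyType("int32")(), ToyType("float32")()
print(add.make_node(x, y).outputs[0].type)  # ToyType(dtype='float32')
print(neg.make_node(x).outputs[0].type)     # ToyType(dtype='int32')
```

Swapping the callable is all it takes to reuse one op class with different type rules; the real `ScalarOp` adds input coercion via `as_scalar` and the validation checks on top of this.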

The whole logic should be documented; in the process, we may find that it can be simplified. Also, every time we use it we check whether the outputs of this function are valid, which is costly. Since nobody is really creating new `output_types_preference` functions, we should just trust them.
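One possible shape for "just trust them", sketched with hypothetical toy classes (`ToyType`, `CheckedOp`, `TrustingOp` are illustrative, not PyTensor code). `CheckedOp` mirrors the current validation in `output_types`; `TrustingOp` returns whatever the preference callable produced, with no `isinstance` or length checks on the hot path.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyType:
    dtype: str


def same_out(*types):
    # Hypothetical preference: output type equals the first input type.
    return (types[0],)


class CheckedOp:
    """Mirrors the current behaviour: validate every preference result."""

    nout = 1

    def __init__(self, output_types_preference):
        self.output_types_preference = output_types_preference

    def output_types(self, types):
        variables = self.output_types_preference(*types)
        if not isinstance(variables, (list, tuple)) or any(
            not isinstance(x, ToyType) for x in variables
        ):
            raise TypeError("output_types_preference should return a list or tuple of types")
        if len(variables) != self.nout:
            raise TypeError(f"Expected {self.nout} output types, got {len(variables)}")
        return variables


class TrustingOp(CheckedOp):
    """Proposed behaviour: trust the preference and skip the checks."""

    def output_types(self, types):
        return self.output_types_preference(*types)


ins = [ToyType("float64")]
print(TrustingOp(same_out).output_types(ins))  # (ToyType(dtype='float64'),)
```

The trade-off is that a buggy preference is no longer caught at this point and would fail later and less clearly (for example when `make_node` calls each returned type), so the checks might instead live in the test suite for the built-in preferences.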

```python
from collections.abc import Callable

from pytensor.graph.basic import Apply
from pytensor.link.c.op import COp
from pytensor.link.c.type import CType

# Excerpt from pytensor/scalar/basic.py; `as_scalar` is defined in the same module.


class ScalarOp(COp):
    nin = -1  # number of inputs; -1 means any number
    nout = 1  # number of outputs

    def __init__(self, output_types_preference=None, name=None):
        self.name = name
        if output_types_preference is not None:
            if not isinstance(output_types_preference, Callable):
                raise TypeError(
                    f"Expected a callable for the 'output_types_preference' argument to {self.__class__}. "
                    f"(got: {output_types_preference})"
                )
            self.output_types_preference = output_types_preference
        elif not hasattr(self, "output_types_preference"):
            # No argument and no class-level default: output_types() will raise.
            self.output_types_preference = None

    def make_node(self, *inputs):
        if self.nin >= 0:
            if len(inputs) != self.nin:
                raise TypeError(
                    f"Wrong number of inputs for {self}.make_node "
                    f"(got {len(inputs)}({inputs}), expected {self.nin})"
                )
        inputs = [as_scalar(input) for input in inputs]
        # Output types are derived from the input types via output_types_preference.
        outputs = [t() for t in self.output_types([input.type for input in inputs])]
        if len(outputs) != self.nout:
            # Plain string (the original trailing comma made this a 1-tuple).
            inputs_str = ", ".join(str(input) for input in inputs)
            raise TypeError(
                f"Not the right number of outputs produced for {self}({inputs_str}). "
                f"Expected {self.nout}, got {len(outputs)}."
            )
        return Apply(self, inputs, outputs)

    def output_types(self, types):
        if self.output_types_preference is not None:
            variables = self.output_types_preference(*types)
            # These validity checks run on every call, which is the cost discussed above.
            if not isinstance(variables, list | tuple) or any(
                not isinstance(x, CType) for x in variables
            ):
                raise TypeError(
                    "output_types_preference should return a list or a tuple of types",
                    self.output_types_preference,
                    variables,
                )
            if len(variables) != self.nout:
                variables_str = ", ".join(str(t) for t in variables)
                raise TypeError(
                    "Not the right number of outputs types produced for "
                    f"{self}({variables_str}) by {self.output_types_preference}. "
                    f"Expected {self.nout}, got {len(variables)}."
                )
            return variables
        else:
            raise NotImplementedError(f"Cannot calculate the output types for {self}")
```
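The `__init__` above resolves the preference in three steps: an explicit constructor argument wins; otherwise a class-level `output_types_preference` attribute (for example one set by a subclass) is kept; otherwise it falls back to `None`, and `output_types` then raises `NotImplementedError`. Below is a toy reimplementation of just that resolution order. `ToyScalarOp`, `SameTypeOp`, `same_out` and `always_float64` are hypothetical, and types are plain strings for brevity.

```python
from collections.abc import Callable


def same_out(*types):
    # Hypothetical preference: output type equals the first input type.
    return (types[0],)


def always_float64(*types):
    # Hypothetical preference: the output is always float64.
    return ("float64",)


class ToyScalarOp:
    def __init__(self, output_types_preference=None):
        # Same resolution order as ScalarOp.__init__ in the excerpt.
        if output_types_preference is not None:
            if not isinstance(output_types_preference, Callable):
                raise TypeError("Expected a callable for 'output_types_preference'")
            self.output_types_preference = output_types_preference
        elif not hasattr(self, "output_types_preference"):
            self.output_types_preference = None

    def output_types(self, types):
        if self.output_types_preference is None:
            raise NotImplementedError(f"Cannot calculate the output types for {self}")
        return self.output_types_preference(*types)


class SameTypeOp(ToyScalarOp):
    # Class-level default. staticmethod keeps self.output_types_preference from
    # being bound as a method (which would pass self as the first "type").
    output_types_preference = staticmethod(same_out)


print(SameTypeOp().output_types(["int32"]))                # ('int32',)
print(SameTypeOp(always_float64).output_types(["int32"]))  # ('float64',)
```

A function passed to the constructor is stored on the instance, so it is never bound and shadows any class-level default.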
