Closed
Description
Just a couple of thoughts regarding #1411.
The current implementation uses strings to identify a dimension. But in the static graph framework of PyTensor it might make more sense to treat dimensions as first-class objects, so that they can have their own graph-like structure. Maybe they could be a different subclass of Variable as well? I think that might lead to cleaner code for users (no typos in dimension names that lead to silent broadcasting and gigantic arrays with an out-of-memory error), and it might also make derived dimensions easier to handle and reason about.
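To make the typo problem concrete, here is a minimal sketch in plain Python (no PyTensor involved). The helper `broadcast_shape` is hypothetical and only mimics the usual rule that operands with named dimensions broadcast over the union of their names:

```python
from math import prod


def broadcast_shape(*operands):
    """Union the named dimensions of all operands.

    Each operand is a dict mapping dimension name -> length.
    """
    merged = {}
    for dims in operands:
        for name, length in dims.items():
            # Same name must mean same length; a new name adds a new axis.
            if merged.setdefault(name, length) != length:
                raise ValueError(f"conflicting lengths for dimension {name!r}")
    return merged


effect = {"country": 200, "treatment": 50}
offset = {"contry": 200}  # typo: meant "country"

result = broadcast_shape(effect, offset)
n_elements = prod(result.values())  # 200 * 50 * 200 instead of 200 * 50
```

With string names the typo is a perfectly valid new dimension, so nothing fails until the array gets allocated. If dimensions were Python objects, the misspelled variable `contry` would presumably fail with a `NameError` as soon as the graph is built.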
I think we could do something like:

    from math import prod

    from pytensor.graph.basic import Apply, Variable
    from pytensor.graph.op import Op
    from pytensor.graph.type import Type
    from pytensor.tensor.variable import TensorVariable

    class DimensionType(Type):
        pass

    # I guess all dimensions have the same type?
    DimType = DimensionType()

    class Dimension(Variable):
        def __init__(self, name=None, length=None):
            pass

        def length(self) -> TensorVariable:
            ...

    class DimOp(Op):
        pass

    class DimConstant(Dimension):
        def __init__(self, name, *, length=None):
            pass

    # The result of stacking two dims
    class Product(DimOp):
        __props__ = ("name",)

        def __init__(self, name=None):
            self.name = name

        def make_node(self, *inputs: Dimension):
            if self.name is not None:
                name = self.name
            else:
                name = f"product[{','.join(input.name for input in inputs)}]"
            # Symbolic product of the input lengths
            output = Dimension(name=name, length=prod(input.length() for input in inputs))
            return Apply(self, inputs, [output])

    def dim(name, *, length=None):
        return DimConstant(name, length=length)

    def stack(variable, *, dims, name=None):
        dim_op = Product(name)
        stacked_dim = dim_op(*dims)
        ...
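As a rough illustration of what "dimensions with their own graph" could buy us, here is a runnable plain-Python stand-in. It does not use PyTensor; `Dim`, `ProductDim` and `base_dims` are hypothetical names, and lengths are ordinary integers rather than symbolic variables:

```python
from math import prod


class Dim:
    """A base dimension. Identity, not the name, decides equality."""

    def __init__(self, name, length=None):
        self.name = name
        self._length = length
        self.parents = ()  # base dimensions have no inputs

    def length(self):
        if self._length is None:
            raise ValueError(f"length of {self.name!r} is unknown")
        return self._length


class ProductDim(Dim):
    """A dimension derived by stacking several parent dimensions."""

    def __init__(self, *parents, name=None):
        if name is None:
            name = f"product[{','.join(p.name for p in parents)}]"
        super().__init__(name)
        self.parents = parents

    def length(self):
        # The stacked length is the product of the parent lengths.
        return prod(p.length() for p in self.parents)


def base_dims(dim):
    """Walk the dimension graph down to its base dimensions."""
    if not dim.parents:
        return (dim,)
    return tuple(b for p in dim.parents for b in base_dims(p))


country = Dim("country", length=3)
treatment = Dim("treatment", length=2)
interaction = ProductDim(country, treatment)
```

Because the derived dimension keeps references to its parents, something like `unstack` would not need extra bookkeeping: it could read the original dimensions straight off the graph, as `base_dims` does here.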
Final usage could maybe be something like this?

    country = pt.dim("country")
    treatment = pt.dim("treatment")

    # We can talk about the stacked dim directly:
    interaction = pt.stacked_dim(country, treatment)
    effect = pt.xtensor(dims=interaction)
    unstacked = effect.unstack(interaction)
    assert unstacked.dims == (country, treatment)
    assert pt.stack(unstacked, dims=[country, treatment]).dims == (interaction,)

    # Same for indices and slices...
    subset = pt.dim_slice(country, slice("A", "B"))
    sub_effect = pt.xtensor(dims=subset)
    effect = pt.xtensor(dims=country)
    assert effect.sel(country=slice("A", "B")).dims == (subset,)
    effect = effect.at.sel(country=slice("A", "B")).add(sub_effect)
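One thing the slice example implies: for `effect.sel(...).dims == (subset,)` to hold when `subset` was created separately, derived dimensions would need to compare by structure (operation plus inputs), while base dimensions still compare by identity. A minimal plain-Python sketch of that rule, where `BaseDim`, `SliceDim` and `sel_dims` are hypothetical stand-ins and not PyTensor API:

```python
from dataclasses import dataclass


@dataclass(frozen=True, eq=False)
class BaseDim:
    """Base dimension: compared by identity, so equal names never collide."""

    name: str


@dataclass(frozen=True)
class SliceDim:
    """Label-based slice of a parent dimension, compared by structure."""

    parent: BaseDim
    start: str
    stop: str


def sel_dims(dims, **selectors):
    """Return the dims after a label-slice selection, keyed by dimension name."""
    out = []
    for d in dims:
        s = selectors.get(d.name)
        # Unselected dims pass through; selected ones become derived dims.
        out.append(d if s is None else SliceDim(d, s.start, s.stop))
    return tuple(out)


country = BaseDim("country")
subset = SliceDim(country, "A", "B")
selected = sel_dims((country,), country=slice("A", "B"))
```

Here `selected == (subset,)` holds because both slices share the same parent object and bounds, whereas a slice of a different `BaseDim("country")` would not compare equal. In PyTensor this might fall out of `__props__` on the Op plus the usual graph merging, but that is an assumption.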