Labeled Tensors as first class objects #1421

@aseyboldt

Description

Just a couple of thoughts regarding #1411.

The current implementation uses strings to identify a dimension. But in pytensor's static graph framework, maybe it would make more sense to treat dimensions as first-class objects, so that they can have their own graph-like structure? Maybe they could even be a separate subclass of Variable? I think that could lead to cleaner code for users (a typo in a dimension name would no longer cause silent broadcasting, a gigantic array, and an out-of-memory error), and it might also make derived dimensions easier to handle and reason about.
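
To make the typo problem concrete, here is a minimal pure-Python sketch of name-based broadcasting. It does not use pytensor at all: `broadcast_dims` and `Dim` are hypothetical names made up for this illustration, and operands are modelled as plain `{dim_name: length}` dicts.

```python
from dataclasses import dataclass


def broadcast_dims(a: dict, b: dict) -> dict:
    """Union of named dims, the way name-based broadcasting aligns operands."""
    out = dict(a)
    for name, length in b.items():
        if name in out and out[name] != length:
            raise ValueError(f"length mismatch for dim {name!r}")
        out[name] = length
    return out


# Intended: both operands share the "country" dim.
x = {"country": 10_000}
y = {"contry": 10_000}  # typo, but still a perfectly valid string

result = broadcast_dims(x, y)
n_elements = 1
for length in result.values():
    n_elements *= length
# The typo goes unnoticed: the result has two dims of length 10_000
# instead of a single shared one.


@dataclass(frozen=True, eq=False)
class Dim:
    """Hypothetical dimension object: identity, not spelling, decides equality."""
    name: str
    length: int


country = Dim("country", 10_000)
# Writing `contry` here instead would raise a NameError right away,
# because it is a Python name and not a free-form string.
```

With `eq=False` two `Dim` objects only compare equal if they are the same object, so two dimensions that happen to share a spelling are still kept apart.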

I think we could do something like

from math import prod

from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.tensor import TensorVariable

class DimensionType(Type):
    pass

# I guess all dimensions have the same type?
DimType = DimensionType()

class DimensionVariable(Variable):
    def __init__(self, name=None, length=None):
        pass

    def length(self) -> TensorVariable:
        ...

class DimOp(Op):
    pass

class DimConstant(DimensionVariable):
    def __init__(self, name, *, length=None):
        pass

# The result of stacking two dims
class Product(DimOp):
    __props__ = ("name",)

    def __init__(self, name=None):
        self.name = name

    def make_node(self, *inputs: DimensionVariable):
        if self.name is not None:
            name = self.name
        else:
            name = f"product[{','.join(inp.name for inp in inputs)}]"
        # length is a method on DimensionVariable, so call it on each input
        length = prod(inp.length() for inp in inputs)
        output = DimensionVariable(name=name, length=length)
        return Apply(self, inputs, [output])

def dim(name, *, length=None):
    return DimConstant(name, length=length)

def stack(variable, *, dims, name=None):
    dim_op = Product(name)
    # Product.make_node takes the dims as separate inputs
    stacked_dim = dim_op(*dims)
    ...
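
To show what the naming and length logic of `Product.make_node` would compute, here is a small stand-alone toy version. It assumes concrete integer lengths instead of symbolic ones, and `ToyDim` and `product_dim` are hypothetical stand-ins rather than pytensor classes.

```python
from dataclasses import dataclass
from math import prod
from typing import Optional, Tuple


@dataclass(frozen=True, eq=False)
class ToyDim:
    """Stand-in for a dimension variable (hypothetical, not a pytensor class)."""
    name: str
    length: int
    parents: Tuple["ToyDim", ...] = ()


def product_dim(*dims: ToyDim, name: Optional[str] = None) -> ToyDim:
    """Mimic Product.make_node: derive the name and length from the inputs."""
    if name is None:
        name = f"product[{','.join(d.name for d in dims)}]"
    return ToyDim(name, prod(d.length for d in dims), parents=dims)


country = ToyDim("country", 3)
treatment = ToyDim("treatment", 2)

# The stacked dim gets a default name and the product of the input lengths.
interaction = product_dim(country, treatment)
named = product_dim(country, treatment, name="interaction")
```

Keeping the inputs around in `parents` is what would let an unstack operation recover the original dims later.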

Final usage could maybe be something like this?

country = pt.dim("country")
treatment = pt.dim("treatment")

# We can talk about the stacked dim directly:
interaction = pt.stacked_dim(country, treatment)
effect = pt.xtensor(dims=interaction)
unstacked = effect.unstack(interaction)
assert unstacked.dims == (country, treatment)
assert pt.stack(unstacked, dims=[country, treatment]).dims == (interaction,)

# Same for indices and slices...
subset = pt.dim_slice(country, slice("A", "B"))
sub_effect = pt.xtensor(dims=subset)
effect = pt.xtensor(dims=country)
assert effect.sel(country=slice("A", "B")).dims == (subset,)
effect = effect.at.sel(country=slice("A", "B")).add(sub_effect)
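
The "graph-like structure" part could look roughly like this: every derived dim remembers the op and inputs that produced it, so we can walk back to the root dims. This is only a toy model in plain Python. `ToyDim`, `root_dims` and the `op` strings are hypothetical, and a real version would presumably reuse pytensor's `Apply` nodes instead.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True, eq=False)
class ToyDim:
    """Hypothetical dim node; `op` is None for root dims."""
    name: str
    op: Optional[str] = None
    inputs: Tuple["ToyDim", ...] = ()


def root_dims(dim: ToyDim) -> list:
    """Walk back through the ops that produced `dim` and collect its roots."""
    if dim.op is None:
        return [dim]
    roots = []
    for parent in dim.inputs:
        for root in root_dims(parent):
            # Compare by identity so equal spellings are not merged.
            if not any(root is seen for seen in roots):
                roots.append(root)
    return roots


country = ToyDim("country")
treatment = ToyDim("treatment")

# A slice of `country` is its own dim, but it still knows where it came from.
subset = ToyDim("country[A:B]", op="slice", inputs=(country,))
interaction = ToyDim("interaction", op="product", inputs=(subset, treatment))

interaction_roots = root_dims(interaction)
```

Something like this is what I mean by derived dimensions being easier to reason about: `subset` and `country` are different objects, but their relationship is explicit in the graph instead of being encoded in a string.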
