
Checking the first dimensions of a tensor #25

@adrianjav


Hi!

I just found torchtyping a few days ago, and I am enjoying it so far. However, I am a bit confused when it comes to one particular use-case: checking if the arguments of a function share the same first dimensions.

For example, if I try to write a function such as batch-wise scalar multiplication:

```python
from torchtyping import TensorType

def batchwise_multiply(data: TensorType['B', ...], weights: TensorType['B']):
    pass
```

I get `NotImplementedError: Having dimensions to the left of ... is not currently supported.`

Why isn't this behaviour implemented? How is it different from the same check with the named dimension on the right?
I haven't looked at the code, but as far as I understand, if `TensorType[..., 'B']` is supported, then on encountering `TensorType['B', ...]` you could reuse the same logic and simply read the shapes backwards, couldn't you?
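A rough sketch of that reversal idea in plain Python, independent of torchtyping's actual internals (which I haven't checked). `match` and `_match_right_aligned` are hypothetical helpers, not part of torchtyping's API; shapes are plain tuples, string entries are named dimensions, ints must match exactly, and `...` may appear once at either end.

```python
def _match_right_aligned(pattern, shape, bindings):
    """Match `shape` against `pattern`, whose only `...` (if any) is leading.

    Dimensions are compared right to left. String entries are named
    dimensions that must take the same size everywhere; int entries must
    match exactly. `bindings` is only updated if the whole match succeeds.
    """
    has_ellipsis = len(pattern) > 0 and pattern[0] is ...
    fixed = pattern[1:] if has_ellipsis else pattern
    if len(shape) < len(fixed) or (not has_ellipsis and len(shape) != len(fixed)):
        return False
    new = dict(bindings)  # work on a copy so a failed match leaves bindings untouched
    for dim, size in zip(reversed(fixed), reversed(shape)):
        if isinstance(dim, int):
            if dim != size:
                return False
        elif new.setdefault(dim, size) != size:
            return False
    bindings.update(new)
    return True


def match(pattern, shape, bindings):
    """Accept `...` at either end of `pattern` (hypothetical helper).

    A trailing `...`, as in ('B', ...), is handled by reversing both the
    pattern and the shape, then reusing the right-aligned matcher.
    """
    pattern, shape = tuple(pattern), tuple(shape)
    n_ellipsis = sum(dim is ... for dim in pattern)
    if n_ellipsis > 1 or (n_ellipsis == 1 and pattern[0] is not ... and pattern[-1] is not ...):
        raise NotImplementedError("this sketch only handles one leading or trailing `...`")
    if pattern and pattern[-1] is ...:
        return _match_right_aligned(pattern[::-1], shape[::-1], bindings)
    return _match_right_aligned(pattern, shape, bindings)
```

For example, with `b = {}`, `match(('B', ...), (4, 3, 2), b)` binds `B` to 4, after which `match(('B',), (5,), b)` fails because `B` is already 4. A `...` in the middle (e.g. `('B', ..., 'C')`) would need a prefix match plus a suffix match, which this sketch deliberately leaves out.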

I think this feature would be a big addition to the library. At least with my programming conventions, I put shared dimensions (such as the batch dimension) in the leading positions, so that later I can unpack tensors using the `*` operator.
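Until dimensions to the left of `...` are supported, one possible workaround is to check the leading dimension by hand. A minimal sketch using plain PyTorch with no torchtyping annotations; the `ValueError` and its message are my own choice, not what torchtyping would raise. It also unpacks the shape with `*`, relying on the batch dimension coming first.

```python
import torch


def batchwise_multiply(data: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Multiply each slice data[i] by the scalar weights[i].

    Manually checks what TensorType['B', ...] and TensorType['B'] would express.
    """
    if data.dim() == 0 or weights.dim() != 1:
        raise ValueError(
            f"expected data of shape (B, ...) and weights of shape (B,), "
            f"got {tuple(data.shape)} and {tuple(weights.shape)}"
        )
    # Leading batch dimension first; the remaining dims are collected by *rest.
    batch, *rest = data.shape
    if weights.shape[0] != batch:
        raise ValueError(f"batch size mismatch: {batch} vs {weights.shape[0]}")
    # Reshape weights to (B, 1, ..., 1) so it broadcasts over the trailing dims.
    return data * weights.view(batch, *([1] * len(rest)))
```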
