Hi!
I found torchtyping a few days ago and I'm enjoying it so far. However, I'm a bit confused about one particular use case: checking whether the arguments of a function share the same leading dimensions.
For example, if I try to write a function for batch-wise scalar multiplication:

```python
from torchtyping import TensorType

def batchwise_multiply(data: TensorType['B', ...], weights: TensorType['B']):
    pass
```

I get `NotImplementedError: Having dimensions to the left of ... is not currently supported.`
Why isn't this behaviour implemented? How is it different from doing the same check on the right of `...`?
I haven't checked the code, but as far as I understand, if `TensorType[..., 'B']` is supported, then when you detect `TensorType['B', ...]` you should be able to reuse the same code and just read the tensor's dimensions in the opposite direction, right?
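Here is a minimal sketch of what I have in mind, in plain Python and independent of torchtyping's actual internals (which I haven't read). `match_dims` is a hypothetical helper, not part of torchtyping's API: it checks a shape tuple against a spec containing at most one `...`, matching the dimensions left of `...` from the front and those right of it from the back. Named dimensions are recorded in a shared `bindings` dict, so two arguments can be required to agree on `'B'`.

```python
from typing import Dict, Sequence, Tuple


def match_dims(spec: Sequence[object], shape: Tuple[int, ...], bindings: Dict[str, int]) -> bool:
    """Hypothetical checker: match `shape` against `spec` (str names, ints, at most one `...`).

    Named dims are stored in `bindings` so later calls can compare sizes across arguments.
    Note: bindings made before a failed check are not rolled back (fine for a sketch).
    """
    if spec.count(...) > 1:
        raise ValueError("at most one '...' is allowed")

    if ... in spec:
        split = spec.index(...)
        left, right = list(spec[:split]), list(spec[split + 1:])
    else:
        left, right = list(spec), []
        if len(shape) != len(left):  # without '...', the rank must match exactly
            return False

    if len(left) + len(right) > len(shape):
        return False

    # Dims left of '...': pair them with the shape from the front.
    pairs = list(zip(left, shape[:len(left)]))
    # Dims right of '...': pair them with the shape read backwards.
    pairs += list(zip(reversed(right), reversed(shape)))

    for dim, size in pairs:
        if isinstance(dim, int):
            if dim != size:
                return False
        elif isinstance(dim, str):
            # First occurrence binds the name; later occurrences must agree.
            if bindings.setdefault(dim, size) != size:
                return False
        else:
            raise TypeError(f"unsupported dimension spec: {dim!r}")
    return True


# Usage: the check that `batchwise_multiply` would want for its two arguments.
bindings: Dict[str, int] = {}
ok = match_dims(['B', ...], (4, 3, 2), bindings) and match_dims(['B'], (4,), bindings)
print(ok, bindings)
```

The left and right cases only differ in which end of the shape the fixed dimensions are paired with, which is why I'd expect the existing right-hand logic to carry over.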
I feel this feature would be huge for the library. At least with my programming conventions, I tend to put common dimensions in leading positions so that later I can unpack tensors using the * operator.