Making Broadcasting More Explicit #1187
Unanswered
Pratham-ja
asked this question in
TinyTorch Q&A
Replies: 0 comments
Hi Professor VJ and everyone else. While going through the loss module (module 04), I came across something I also struggled with when learning to implement torch-like code from scratch using NumPy. Often there is a mathematical mistake in our operations: we intend an element-wise operation, but the hidden broadcasting flexibility does not raise any error. What looks correct, because neither NumPy nor torch complains, can silently produce wrong results that are not what the student actually intended.
So, for example:
Suppose I have some final predictions and want the mean squared error over 32 examples with 4 output values each (as in a multiple-output linear regression problem). If, due to a typo or a missed dimension while aggregating along some axis (something I do myself quite often), we end up with pred.shape = (32, 4) but y_true.shape = (4,), then in the MSELoss class implementation we calculate
diff = predictions.data - targets.data
This allows shape mismatches like:
pred: (32, 4)
target: (4,)
which will silently broadcast to (32, 4), so the same y_true row is reused for every example. By contrast, scikit-learn's mean_squared_error raises a shape-mismatch error in this case.
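To make the failure mode concrete, here is a minimal NumPy sketch of the mismatch described above: the subtraction succeeds with no error, and every example silently gets compared against the same four target values.

```python
import numpy as np

# 32 examples, 4 output values each
pred = np.ones((32, 4))

# Intended: one target row per example, shape (32, 4).
# Typo: targets collapsed to shape (4,).
y_true = np.array([0.0, 1.0, 2.0, 3.0])

# NumPy silently broadcasts (4,) against (32, 4): no error is
# raised, and the same y_true is used for all 32 examples.
diff = pred - y_true
print(diff.shape)  # (32, 4), with no hint that broadcasting happened

# scikit-learn's mean_squared_error rejects the same mismatch:
# from sklearn.metrics import mean_squared_error
# mean_squared_error(y_true, pred)  # raises ValueError
```

Nothing in the traceback or the result shape tells the user that the (4,) operand was stretched to (32, 4).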
TinyTorch mirrors PyTorch behavior, but since it is built for education, clearly exposing what goes on under the hood in every line of code, I believe explicit shape matching should be enforced to prevent unintended mistakes due to broadcasting. This would help students understand broadcasting better and use the feature more consciously.
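As a rough sketch of what I mean by enforcement (the class name StrictMSELoss and its interface are hypothetical, not TinyTorch's actual API), the loss could check for exact shape equality before subtracting:

```python
import numpy as np

class StrictMSELoss:
    """Hypothetical MSE loss that requires exact shape equality,
    refusing to rely on silent broadcasting."""

    def __call__(self, predictions, targets):
        if predictions.shape != targets.shape:
            raise ValueError(
                f"Shape mismatch: predictions {predictions.shape} vs "
                f"targets {targets.shape}; broadcasting is not allowed here."
            )
        diff = predictions - targets
        return float(np.mean(diff ** 2))

loss = StrictMSELoss()
pred = np.zeros((32, 4))

ok = loss(pred, np.ones((32, 4)))  # shapes match exactly: fine

try:
    loss(pred, np.ones(4))  # (4,) would have silently broadcast
    mismatch_rejected = False
except ValueError as e:
    mismatch_rejected = True
    print("caught:", e)
```

A student who genuinely wants broadcasting could still expand the targets explicitly (e.g. with np.broadcast_to) before calling the loss, which makes the intent visible in the code.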
Alternatively, this could be made very general in TinyTorch: whenever broadcasting occurs in an operation, print a message saying that the operation involves broadcasting along some axis. That would make users more informed and careful.
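One possible way to sketch that idea (the helper name sub_with_broadcast_warning is hypothetical): compare each operand's shape against the broadcast result, and emit a warning whenever they differ.

```python
import warnings
import numpy as np

def sub_with_broadcast_warning(a, b):
    """Hypothetical helper: compute a - b, but warn when NumPy
    broadcasting stretched either operand to a new shape."""
    out_shape = np.broadcast_shapes(a.shape, b.shape)
    if a.shape != out_shape or b.shape != out_shape:
        warnings.warn(
            f"broadcasting {a.shape} and {b.shape} -> {out_shape}",
            stacklevel=2,
        )
    return a - b

# Capture warnings so the demo is self-checking.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    d = sub_with_broadcast_warning(np.ones((32, 4)), np.ones(4))

print(d.shape, len(caught))
```

The same check could live inside the tensor's arithmetic methods so that every broadcasting operation announces itself, instead of requiring a special helper at each call site.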
Would love to hear @profvjreddi's and others' thoughts on this.