Broadcasting Notes
I've looked into what numpy describes as broadcasting, and in my opinion it is really two different PyTorch features:
- broadcasting (as numpy defines it)
- batching
The rest of this document will describe for each class of function how broadcasting and/or batching should be implemented.
| torch function |
|---|
| add |
| atan2 |
| div |
| fmod |
| lerp |
| mul |
| pow |
| remainder |
| dist |
| sub |
The out-of-place functions follow normal numpy-style broadcast semantics. Batching semantics don't apply, since the operations are pointwise.
The in-place functions follow the same rules, with the additional restriction that the resulting tensor cannot change size. Note that this is the same behavior as numpy, although I couldn't find numpy documentation to this effect.
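For example, a small sketch of these semantics (shapes chosen arbitrarily):

```python
import torch

a = torch.randn(4, 1)
b = torch.randn(1, 5)

# Out-of-place: numpy-style broadcasting, the result has shape (4, 5).
c = torch.add(a, b)
print(c.size())   # torch.Size([4, 5])

# In-place: the other operand may be broadcast to self's size, but self's size
# cannot change. So a.add_(b) should raise an error (it would need a (4, 5)
# result), while c.add_(a) is fine (a is broadcast from (4, 1) to (4, 5)).
c.add_(a)
```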
| torch function |
|---|
| eq |
| ge |
| gt |
| le |
| lt |
| max |
| min |
| ne |
The broadcasting behavior of these functions is the same as the pointwise math functions.
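For instance (a small sketch):

```python
import torch

a = torch.randn(4, 1)
b = torch.randn(1, 5)

# Comparison ops broadcast the same way as the pointwise math ops;
# the result has shape (4, 5).
print(torch.eq(a, b).size())   # torch.Size([4, 5])
```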
| torch function |
|---|
| addcdiv |
| addcmul |
When in-place, the non-result tensors are broadcast to the size of the resulting tensor (i.e. same behavior as the 2-operand case).
For out-of-place, the behavior should be equivalent to breaking up the operation. I.e. addcmul(C,A,B) is equivalent to add(C,mul(A,B)). With broadcasting behavior, that would mean we first broadcast A and B together, then broadcast the result with C. It's easier to implement this as broadcasting all 3 operands together to start; here's a proof sketch showing they are equivalent:
Consider addcmul: under expansion we want a + (b * c) = (a + b * c) [all three operands expanded together].
Let e(i, j) be the expansion of i with j, and e(i, j, k) the expansion of i with j and k. Then:
a + (b * c) = e(a, e(b,c) * e(c,b)) + e(e(b,c) * e(c,b), a)
= e(a, e(b,c)) + e(e(b,c) * e(c,b), a) (only the size of the second parameter matters)
= e(a, b, c) + e(e(b,c) * e(c,b), a) (by associativity of max in expand)
= e(a, b, c) + e(b, c, a) * e(c, b, a) (by Lemma 1)
which is a + b * c with all three operands expanded together.
Lemma 1: show that e(i * j, a) = e(i, a) * e(j, a), where i and j have the same size.
Consider any index {s_0, ..., s_n}:
e(i * j, a) = (i * j)_{f(s_0), ..., f(s_n)}, where f is the expansion of that dimension with a
= i_{f(s_0), ..., f(s_n)} * j_{f(s_0), ..., f(s_n)} (by definition of a pointwise operator)
= e(i, a) * e(j, a)
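A quick numerical check of the equivalence (a sketch; the shapes are arbitrary, and it assumes addcmul broadcasts all three operands as proposed):

```python
import torch

C = torch.randn(4, 1)
A = torch.randn(1, 5)
B = torch.randn(5)

# Two-step broadcasting: broadcast A with B, then the product with C.
two_step = torch.add(C, torch.mul(A, B))

# One-step: expand all three operands to the common (4, 5) size up front.
one_step = C.expand(4, 5) + A.expand(4, 5) * B.expand(4, 5)

print((two_step - one_step).abs().max())                 # ~0
print((torch.addcmul(C, A, B) - one_step).abs().max())   # ~0, if addcmul broadcasts
```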
Some torch functions previously allowed relaxed pointwise shape requirements: as long as the numbers of elements were equal, the arguments were accepted and the pointwise operation behaved as if the tensors were 1-d (the resulting tensor has the same shape as the first argument).
I implemented the following fallback procedure, which prefers broadcasting:
- check if the arguments are broadcastable; if they are, broadcast and run the function
- if the arguments are not broadcastable but the numbers of elements are equal, fall back to the previous 1-d behavior.
Note that this fallback procedure is not backwards compatible in the case where the arguments are broadcastable and also have the same number of elements. For example, consider adding tensors of shape (1,4) and (4,1). Previously, this would result in a (1,4) tensor, but with broadcasting semantics it now results in a (4,4) tensor. It is often obvious where in code the behavior changed, because subsequent calls fail due to size mismatches, but this isn't always the case (e.g. imagine taking the sum over the resulting tensor (which represents an error), which now has 4 times as many non-negative elements).
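A small sketch of the (1,4) + (4,1) case described above:

```python
import torch

a = torch.randn(1, 4)
b = torch.randn(4, 1)

# Both tensors have 4 elements, so the old 1-d fallback produced a (1, 4)
# result; under broadcasting semantics the result is (4, 4).
print(torch.add(a, b).size())   # torch.Size([4, 4])
```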
Also note that the combination of changing to broadcasting semantics and changing the keepdim default to False on reductions can cause some calculations to "just work" when either change introduced independently would cause them to fail. For example:
running_mean = torch.randn(4)
input = torch.randn(4,4)
input_mean = input.mean(1)
diff = torch.sum( (input_mean - running_mean).abs() )
With broadcasting and keepdim=False, the sum is over 4 elements, as desired. With broadcasting and keepdim=True, the sum is over 4*4 elements.
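A quick check of the two cases (sketch):

```python
import torch

running_mean = torch.randn(4)
input = torch.randn(4, 4)

# keepdim=False: (4,) - (4,) stays (4,), so the sum is over 4 elements.
print((input.mean(1) - running_mean).size())                 # torch.Size([4])

# keepdim=True: (4, 1) - (4,) broadcasts to (4, 4), so the sum is over 16 elements.
print((input.mean(1, keepdim=True) - running_mean).size())   # torch.Size([4, 4])
```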
The functions that behave this way are:
- pow
- add
- sub
- mul
- div
- fmod
- remainder
- addcmul
- addcdiv
The rules for numpy matmul are described in the numpy.matmul documentation.
PyTorch doesn't currently implement the case where the dimensionality is greater than 2, which the numpy documentation describes as:
> If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
For example:
>>> a_np = np.random.randn(2,5,7)
>>> b_np = np.random.randn(5,2,7,3)
>>> c_np = a_np @ b_np
>>> c_np.shape
(5, 2, 5, 3)
Note that while this is referred to as broadcasting, in PyTorch terms this is essentially broadcasting followed by batching+reshaping, i.e. PyTorch only supports the batching case where ndims(A) == ndims(B) == 3 without broadcasting.
For example:
>>> a=torch.from_numpy(a_np)
>>> b=torch.from_numpy(b_np)
>>> c=torch.bmm(a.expand(5,2,5,7).contiguous().view(5*2,5,7),b.view(5*2,7,3)).view(5,2,5,3)
>>> (c-torch.from_numpy(c_np)).abs().max()
8.881784197001252e-16
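The same expand + bmm recipe can be written generically. Here is a rough sketch (matmul_nd is a hypothetical helper name, torch.broadcast_shapes is only used for the sketch, and 1-d inputs are not handled):

```python
import torch

def matmul_nd(a, b):
    # Broadcast the batch dimensions (everything except the last two),
    # fold them into a single batch dimension, run bmm, and unfold.
    batch = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    n, m = a.shape[-2:]
    _, p = b.shape[-2:]
    a = a.expand(*batch, n, m).contiguous().view(-1, n, m)
    b = b.expand(*batch, m, p).contiguous().view(-1, m, p)
    return torch.bmm(a, b).view(*batch, n, p)

# matmul_nd(a, b) reproduces the (5, 2, 5, 3) result from the example above.
```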
Implementing the numpy behavior won't cause any backwards incompatibility because PyTorch currently errors out if the dimensionality of either input is greater than 2.
| torch function |
|---|
| mm |
| mv |
| bmm |
If we take the view that matmul dispatches to these functions to perform numpy matmul semantics, then these functions themselves don't need extra broadcasting or batching semantics, i.e. they are lower-level functions.
| torch function | numpy equivalent | supports batching | supports broadcasting |
|---|---|---|---|
| dot | numpy.dot | no | no |
| ger | numpy.outer | no | no |
Nothing to do here.
The fused functions are:
| torch function |
|---|
| addmm |
| addmv |
| addr |
These "fused" functions should behave as if the operation were broken up, e.g. addmm(C,A,B) is equivalent to add(C, mm(A,B)). Given that mm, mv, ger do not broadcast (see above), we should only broadcast the add.
| torch function |
|---|
| baddbmm |
| addbmm |
The same logic applies as for the non-batched functions, e.g. baddbmm(C,A,B) = add(C, bmm(A,B)). Since bmm does not broadcast, only the add should broadcast.
Putting these all together:
| Function | A.shape | B.shape | unfused equivalent | unfused size | broadcast of C |
|---|---|---|---|---|---|
| addmm(C,A,B) | (n,m) | (m,p) | add(C, mm(A,B)) | add(C, (n,p)) | (n,p) |
| addmv(C,A,B) | (n,m) | (m) | add(C, mv(A,B)) | add(C, (n)) | (n) |
| addr(C,A,B) | (n) | (m) | add(C, ger(A,B)) | add(C, (n,m)) | (n,m) |
| baddbmm(C,A,B) | (b,n,m) | (b,m,p) | add(C, bmm(A,B)) | add(C, (b,n,p)) | (b,n,p) |
| addbmm(C,A,B) | (b,n,m) | (b,m,p) | add(C, sum(bmm(A,B),0)) | add(C, sum((b,n,p),0)) = add(C, (n,p)) | (n,p) |
Note that because PyTorch is currently strict about tensor sizes for BLAS operations, there are no backwards compatibility concerns with implementing this type of broadcasting.
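And similarly for the batched case (a sketch, assuming baddbmm broadcasts C as in the table above):

```python
import torch

A = torch.randn(6, 4, 3)
B = torch.randn(6, 3, 5)
C = torch.randn(4, 5)    # broadcasts against the (6, 4, 5) result of bmm(A, B)

fused = torch.baddbmm(C, A, B)
unfused = torch.add(C, torch.bmm(A, B))
print((fused - unfused).abs().max())   # ~0
```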
There isn't a 1-to-1 mapping of numpy lapack functions to torch lapack functions, but the numpy.linalg package contains close analogs.
Many numpy.linalg functions with a single ndarray operand claim to support broadcasting (see e.g. numpy.linalg.svd), which is strange because broadcasting is only described in terms of two operands. What numpy actually seems to mean is that these functions support batching+reshaping (à la matmul), i.e.:
>>> a=np.array([[0,1], [1,1]])
>>> b=np.array([[0,1], [1,1]])
>>> np.linalg.inv(a)
array([[-1., 1.],
[ 1., 0.]])
>>> np.linalg.inv(b)
array([[-1., 1.],
[ 1., 0.]])
>>> np.linalg.inv([a,b])
array([[[-1., 1.],
[ 1., 0.]],
[[-1., 1.],
[ 1., 0.]]])
| torch function | numpy/scipy equivalent | supports batching | supports broadcasting |
|---|---|---|---|
| inverse | numpy.linalg.inv | yes | no |
| eig | numpy.linalg.eig | yes | no |
| symeig | numpy.linalg.eigh | yes | no |
| qr | numpy.linalg.qr | no | no |
| svd | numpy.linalg.svd | yes | no |
| btrifact | scipy.linalg.lu_factor | no | no |
| btrisolve | scipy.linalg.lu_solve | no | no |
| geqrf | scipy.linalg.lapack.dgeqrf | no | no |
| orgqr | scipy.linalg.lapack.dorgqr | no | no |
| ormqr | scipy.linalg.lapack.dormqr | no | no |
| potrf | scipy.linalg.lapack.dpotrf | no | no |
| potri | scipy.linalg.lapack.dpotri | no | no |
| potrs | scipy.linalg.lapack.dpotrs | no | no |
| pstrf | none, numpy.linalg.cholesky closest? | yes | no |
Since none of these really support broadcasting (only batching), this should be viewed as a separate issue.
| torch function | numpy/scipy equivalent | supports batching | supports broadcasting |
|---|---|---|---|
| gels | numpy.linalg.lstsq | yes | no |
| gesv | numpy.linalg.solve | yes | yes*, see below |
numpy.linalg.solve is actually two functions:
* solve: (m,m), (m,n) -> (m,n)
* solve1: (m,m), (m) -> (m)
both of which support batching and broadcasting. solve1 is selected iff ndims(A) - 1 == ndims(b). This leads to some weird results as the dimensions of the tensors change, e.g.:
# not solve-1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(6,15)).shape
(2, 4, 5, 9, 6, 15)
# not solve-1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(9,6,15)).shape
(2, 4, 5, 9, 6, 15)
# not solve-1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(5,9,6,15)).shape
(2, 4, 5, 9, 6, 15)
# solve-1, old pattern doesn't work:
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(4,5,9,6,15)).shape
Traceback (most recent call last):
...
ValueError: solve1: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (m,m),(m)->(m) (size 15 is different from 6)
# need to match up dimensions according to solve1
>>> np.linalg.solve(np.random.randn(2,4,5,9,6,6), np.random.randn(2,4,5,9,6)).shape
(2, 4, 5, 9, 6)
This seems unnecessarily complex, i.e. different behavior should be different functions.
torch.gesv currently supports A having shape (m,m) and B having shape (m) or (m,k). B having shape (m) is treated the same as B having shape (m,1) (note: torch.gels has the same behavior).
Note that we need to make a similar decision as numpy. Consider the case where B has shape (5,5) and A has shape (5,5,5). If B is interpreted as a matrix, then the result has shape (5,5,5). If B is interpreted as a vector, equivalent to (5,5,1) (when viewed as a matrix), then the result has shape (5,5,1).
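As a concrete check of the dispatch rule described above (ndims(A) - 1 == ndims(b) selects solve1), numpy treats a 2-d B as a stack of vectors here:

```python
import numpy as np

A = np.random.randn(5, 5, 5)
B = np.random.randn(5, 5)

# ndims(A) - 1 == ndims(B), so under the solve1 rule described above B is
# treated as a stack of vectors and the result has shape (5, 5); interpreting
# B as a matrix would instead give a (5, 5, 5) result.
print(np.linalg.solve(A, B).shape)
```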
To avoid these kinds of complications, it seems nicer to just expose two different functions, gesv and gesv1: gesv interprets B as a matrix (except when B is 1-dimensional, in which case it behaves as it does now), and gesv1 interprets B as a vector.
Note that this is backwards compatible because the previous shape of gesv[0] and new shape of gesv[0] for current valid inputs are equivalent:
| A.shape | B.shape | Previous shape gesv[0] | New shape gesv[0] | Shape gesv1[0] |
|---|---|---|---|---|
| (m,m) | (m) | (m,1) | (m,1) | (m) |
| (m,m) | (m,1) | (m,1) | (m,1) | (m,1) if m=1, otherwise Error (vector is size 1, not m) |
| (m,m) | (m,k), k != 1 | (m,k) | (m,k) | (m,k) if m=k, otherwise Error (vector is size k, not m) |
Note that this isn't ideal, because in the case where B is 1-dimensional, the shape of gesv[0] is not the same as the shape of gesv1[0], but that is necessary for backwards compatibility. In other cases (e.g. with pointwise functions), we have preferred numpy semantics over backwards compatibility, but in this case, given that numpy doesn't have (in my opinion) reasonable semantics, we should prefer backwards compatibility.
The alternative, preferring consistency over backwards compatibility, would be to change the output shape in the 1-dimensional B case from (m,1) to (m) (and do the same for gels).
| torch function | numpy/scipy equivalent | supports broadcasting |
|---|---|---|
| cat | numpy.concatenate | no |
| gather | numpy.choose | yes, although numpy.choose only supports dim=0. |
| scatter | No equivalent | Yes, for consistency with gather |
| index_select | numpy.take | no |
| masked_select | numpy boolean/mask index arrays | no, see advanced indexing |
Gather explanation: if we have a tensor of shape (x_0, x_1, ..., x_n) and dim=i, we want to expand index to (x_0, x_1, ..., x_{i-1}, f_i(index), x_{i+1}, ..., x_n), where f_i(index) is the size of dimension i of the index tensor. There are three cases to consider:
| Case | Solution |
|---|---|
| i > index.dim() | expand as above, with f_i(index) == 1 and squeeze i, i.e. see * |
| i <= index.dim() < n | Note this doesn't come up for numpy. Specific example: tensor.size() == (3,5,7), index.size() == (5,7) -- do we expand like (5,1,7) or (1,5,7)? Following the rule of prepend 1s, should be (1,5,7) -> (3,5,7) |
| index.dim() == n | expand as above, with f_i(index) == index.size()[i] |
Scatter rules: If out is an (x_0, x_1, ..., x_j, ..., x_n) tensor and dim == j, then index must be an (x_0, x_1, ..., i_j, ..., x_n) tensor and source must be an (x_0, x_1, ..., s_j, ..., x_n) tensor and i_j must be <= s_j (THC currently enforces that these are equal). So, following the gather rules we have:
- out is guaranteed to be defined because this is an in-place call
- If either index or source has at least i dimensions, and its i-th dimension is d_i (let's assume they are equal, as in THC), we broadcast them to (x_0, x_1, ..., d_i, ..., x_n).
- If neither index nor source has at least i dimensions, ...
*:
>>> np.array_equal(np.choose(np.broadcast_to(np.zeros(4), (1,3,4)).astype('int64'),x).reshape(3,4), np.choose(np.zeros(4).astype('int64'),x))
True
| torch function | numpy/scipy equivalent | supports broadcasting |
|---|---|---|
| copy_ | numpy.copyto | yes |
| masked_copy | numpy.copyto | yes |
| masked_fill | nothing direct, numpy.full is closest | yes, for consistency with masked_copy |
| map_ | numpy.vectorize | yes |
| map2_ | numpy.vectorize | yes |
| index_add_ | No equivalent | No, see Index example |
| index_copy_ | No equivalent | No, see Index example |
Index example:
torch.zeros(5,4).index_copy_(0, torch.LongTensor([0,2,1]), torch.arange(1,13).view(3,4))
i.e. the sizes of the destination tensor and the source don't need to match.
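For the entries marked as broadcasting above, e.g. copy_, the intended behavior mirrors numpy.copyto: the source is broadcast to the destination's size (a small sketch):

```python
import torch

dst = torch.zeros(3, 4)
src = torch.arange(4.)

# The (4,) source is broadcast across the rows of the (3, 4) destination.
dst.copy_(src)
print(dst)
```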
| Type | Summary | Backwards Compatible? | Status |
|---|---|---|---|
| pointwise | same behavior as numpy | no | PR#1563 |
| matmul | same behavior as numpy | yes | in progress |
| BLAS | add broadcasts for fused functions, e.g. addmm | yes | in progress |
| LAPACK | gesv splits into gesv and gesv1 | yes | no plans currently |
| Indexing, Slicing, Joining, Mutating Ops | gather, scatter should broadcast according to the (complex) rules above | ? | no plans currently |
| Uncategorized tensor functions | copy, etc. | ? | in progress |