[0035] Align matrix-vector APIs with coopvec #741
This aligns the matrix-vector APIs with the SM 6.9 cooperative vector feature such that the matrix is an `A` matrix and the vectors are column vectors rather than row vectors.
@mjbedy, I believe this aligns with our discussion earlier.
``` llvm
declare <[NUMo] x [TYo]> @dx.op.matvecmul.v[NUMo][TYo].v[NUMi][TYi](
  immarg i32,            ; opcode
  %dx.types.MatrixRef,   ; matrix A
```
FWIW: in 0029 the op was vector followed by matrix.
Am I missing something here? The description says we're aligning with the matrix-vector APIs, but as Damyan pointed out, that spec has the vector followed by a matrix.
The observation was about the mathematical operation that is being described. It's the difference between:
vec = mat * vec //SM6.9
vec = vec * mat //SM6.10
Looking at the 6.9 spec, there are a few places where it's inconsistent as well. The 6.9 spec does describe the operation as "Matrix-Vector Multiply" (and in fact, even here, the name is "matvecmul" and the matrix argument is called "matrix A").
SM6.9 says "The @dx.op.matvecmul operation multiplies a MxK dimension matrix and a K sized input vector."
SM6.10 says "The input vector length matches the M matrix dimension" - which (if it were left as is) might be considered inconsistent as well if a B matrix is intended to be K x N dimension.
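The two conventions contrasted above can be sketched on the CPU as follows. This is only an illustrative reference, not the HLSL or DXIL API: `Mat`, `MatMulVec`, `VecMulMat`, and `Example` are hypothetical names introduced here, and the matrix is assumed to be a plain row-major array.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical helper type (not part of dx::linalg): a Rows x Cols matrix,
// stored row-major.
template <typename T, std::size_t Rows, std::size_t Cols>
using Mat = std::array<std::array<T, Cols>, Rows>;

// Column-vector convention (vec = mat * vec): the input length must equal
// the column count, and the result length equals the row count.
template <typename T, std::size_t Rows, std::size_t Cols>
std::array<T, Rows> MatMulVec(const Mat<T, Rows, Cols> &A,
                              const std::array<T, Cols> &x) {
  std::array<T, Rows> y{};
  for (std::size_t r = 0; r < Rows; ++r)
    for (std::size_t c = 0; c < Cols; ++c)
      y[r] += A[r][c] * x[c];
  return y;
}

// Row-vector convention (vec = vec * mat): the input length must equal
// the row count, and the result length equals the column count.
template <typename T, std::size_t Rows, std::size_t Cols>
std::array<T, Cols> VecMulMat(const std::array<T, Rows> &x,
                              const Mat<T, Rows, Cols> &B) {
  std::array<T, Cols> y{};
  for (std::size_t r = 0; r < Rows; ++r)
    for (std::size_t c = 0; c < Cols; ++c)
      y[c] += x[r] * B[r][c];
  return y;
}

// Usage: the same 2x3 matrix accepts a 3-element vector in one convention
// and a 2-element vector in the other.
inline void Example() {
  const Mat<int, 2, 3> A = {{{1, 2, 3}, {4, 5, 6}}};
  assert((MatMulVec(A, std::array<int, 3>{1, 0, 1}) ==
          std::array<int, 2>{4, 10}));
  assert((VecMulMat(std::array<int, 2>{1, 1}, A) ==
          std::array<int, 3>{5, 7, 9}));
}
```

Because the dimensions are template parameters, passing a vector of the wrong length to either helper fails to compile, which is roughly the property the validation rules in the proposal are meant to enforce.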
Yes, and that's probably why I had the HLSL API that way, but I've changed it here intentionally so the HLSL API and DXIL match ordering.
I've updated the description to clarify that the HLSL APIs are updated to align with the SM 6.9 HLSL API, and the DXIL is updated to align with the HLSL APIs where SM 6.9 had different argument ordering between HLSL and DXIL.
proposals/0035-linalg-matrix.md
void CoopVec() {
  using namespace dx::linalg;
  using MatrixBTy =
Should this be changed to `MatrixATy` now?
proposals/0035-linalg-matrix.md
      Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Thread>;

  vector<float16_t, 16> Vec = (vector<float16_t, 16>)0;
  MatrixBTy MatB = MatrixBTy::Load(
`MatA`?
mjbedy left a comment
I think there are a few places where M, N, and K need fixing up as well.
`Thread` scope.

Validation will enforce that:
* The input vector length matches the `M` matrix dimension
Suggested change:
- * The input vector length matches the `M` matrix dimension
+ * The input vector length matches the `K` matrix dimension
* The input vector length matches the `M` matrix dimension
* The bias vector length matches the `N` matrix dimension
Suggested change:
- * The input vector length matches the `M` matrix dimension
- * The bias vector length matches the `N` matrix dimension
+ * The input vector length matches the `K` matrix dimension
+ * The bias vector length matches the `M` matrix dimension
vector<OutputElTy, K> Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
                               vector<InputElTy, M>);
Suggested change:
- vector<OutputElTy, K> Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
-                                vector<InputElTy, M>);
+ vector<OutputElTy, M> Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
+                                vector<InputElTy, K>);
vector<OutputElTy, K> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
                                  vector<InputElTy, M>, vector<BiasElTy, K>);
Suggested change:
- vector<OutputElTy, K> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
-                                   vector<InputElTy, M>, vector<BiasElTy, K>);
+ vector<OutputElTy, M> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
+                                   vector<InputElTy, K>, vector<BiasElTy, M>);
MxK `A` matrix with `Thread` scope, and a `K`-element vector. The operation
multiplies the `M`-element vector by the matrix then adds the `K`-element vector
producing a result `K`-element vector.
Suggested change:
- MxK `A` matrix with `Thread` scope, and a `K`-element vector. The operation
- multiplies the `M`-element vector by the matrix then adds the `K`-element vector
- producing a result `K`-element vector.
+ MxK `A` matrix with `Thread` scope, and a `M`-element vector. The operation
+ multiplies the `K`-element vector by the matrix then adds the `M`-element vector
+ producing a result `M`-element vector.
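For reference, the corrected wording corresponds to the following shapes, written as a plain affine map. The symbols $x$, $b$, and $y$ are editorial names for the input, bias, and result vectors and do not appear in the proposal:

```latex
y = A x + b, \qquad
A \in \mathbb{R}^{M \times K},\quad
x \in \mathbb{R}^{K},\quad
b \in \mathbb{R}^{M},\quad
y \in \mathbb{R}^{M},
\qquad
y_m = \sum_{k=0}^{K-1} A_{m,k}\, x_k + b_m .
```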
mjbedy left a comment
There are a few other places too, but the GitHub UI isn't letting me attach suggestions.
vector<OutputElTy, K> Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
                               vector<InputElTy, M>);
Suggested change:
- vector<OutputElTy, K> Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
-                                vector<InputElTy, M>);
+ vector<OutputElTy, M> Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
+                                vector<InputElTy, K>);
vector<OutputElTy, K> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
                                  vector<InputElTy, M>, vector<BiasElTy, K>);
Suggested change:
- vector<OutputElTy, K> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
-                                   vector<InputElTy, M>, vector<BiasElTy, K>);
+ vector<OutputElTy, M> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
+                                   vector<InputElTy, K>, vector<BiasElTy, M>);
This aligns the matrix-vector HLSL APIs with the SM 6.9 cooperative vector feature such that the matrix is an `A` matrix and the vectors are column vectors rather than row vectors. It also aligns the argument order between the HLSL and DXIL APIs to make them easier to read (SM 6.9 had a mismatch between the HLSL APIs and DXIL).