-
Notifications
You must be signed in to change notification settings - Fork 52
[0035] Align matrix-vector APIs with coopvec #741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -190,42 +190,39 @@ Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::ThreadGroup>, | |||||||||||||
|
|
||||||||||||||
| template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K, | ||||||||||||||
| ComponentEnum MatrixDT, MatrixScopeEnum Scope> | ||||||||||||||
| vector<OutputElTy, K> Multiply(vector<InputElTy, M>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, Scope>); | ||||||||||||||
| vector<OutputElTy, K> Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>, | ||||||||||||||
| vector<InputElTy, M>); | ||||||||||||||
|
|
||||||||||||||
| template <typename OutputElTy, typename InputElTy, typename BiasElTy, | ||||||||||||||
| SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT, | ||||||||||||||
| MatrixScopeEnum Scope> | ||||||||||||||
| vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, Scope>, | ||||||||||||||
| vector<BiasElTy, K>); | ||||||||||||||
| vector<OutputElTy, K> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>, | ||||||||||||||
| vector<InputElTy, M>, vector<BiasElTy, K>); | ||||||||||||||
|
Comment on lines
+199
to
+200
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| template <typename OutputElTy, typename InputElTy, | ||||||||||||||
| ComponentEnum InputInterp, typename BiasElTy, SIZE_TYPE M, | ||||||||||||||
| SIZE_TYPE N, SIZE_TYPE K, ComponentEnum MatrixDT, | ||||||||||||||
| MatrixScopeEnum Scope> | ||||||||||||||
| template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp, | ||||||||||||||
| typename BiasElTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K, | ||||||||||||||
| ComponentEnum MatrixDT, MatrixScopeEnum Scope> | ||||||||||||||
| typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size == | ||||||||||||||
| M, | ||||||||||||||
| vector<OutputElTy, K> >::type | ||||||||||||||
| MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, Scope>, | ||||||||||||||
| MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>, | ||||||||||||||
| InterpretedVector<InputElTy, N, InputInterp>, | ||||||||||||||
| vector<BiasElTy, K>); | ||||||||||||||
|
|
||||||||||||||
| template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy, | ||||||||||||||
| SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT> | ||||||||||||||
| vector<OutputElTy, K> | ||||||||||||||
| MultiplyAdd(vector<InputElTy, M>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>, | ||||||||||||||
| VectorRef<BiasElTy, K>); | ||||||||||||||
| MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>, | ||||||||||||||
| vector<InputElTy, M>, VectorRef<BiasElTy, K>); | ||||||||||||||
|
|
||||||||||||||
| template <typename OutputElTy, typename InputElTy, | ||||||||||||||
| ComponentEnum InputInterp, ComponentEnum BiasElTy, | ||||||||||||||
| SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K, ComponentEnum MatrixDT> | ||||||||||||||
| template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp, | ||||||||||||||
| ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K, | ||||||||||||||
| ComponentEnum MatrixDT> | ||||||||||||||
| typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size == | ||||||||||||||
| M, | ||||||||||||||
| vector<OutputElTy, K> >::type | ||||||||||||||
| MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>, | ||||||||||||||
| MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>, | ||||||||||||||
| InterpretedVector<InputElTy, N, InputInterp>, | ||||||||||||||
| VectorRef<BiasElTy, K>); | ||||||||||||||
|
|
||||||||||||||
| // Outer product functions | ||||||||||||||
|
|
@@ -282,32 +279,30 @@ ByteAddressBuffer B : register(t0); | |||||||||||||
|
|
||||||||||||||
| void CoopVec() { | ||||||||||||||
| using namespace dx::linalg; | ||||||||||||||
| using MatrixBTy = Matrix<ComponentType::F16, 16, 16, MatrixUse::B, | ||||||||||||||
| MatrixScope::Thread>; | ||||||||||||||
| using MatrixATy = | ||||||||||||||
| Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Thread>; | ||||||||||||||
|
|
||||||||||||||
| vector<float16_t, 16> Vec = (vector<float16_t, 16>)0; | ||||||||||||||
| MatrixBTy MatB = MatrixBTy::Load( | ||||||||||||||
| MatrixATy MatA = MatrixATy::Load( | ||||||||||||||
| MBuf, 0, /* Row stride = number of columns * element size */ 16 * 4, | ||||||||||||||
| MatrixLayout::RowMajor); | ||||||||||||||
| vector<float16_t, 16> Layer1 = Multiply<float16_t>(Vec, MatB); | ||||||||||||||
| vector<float16_t, 16> Layer1 = Multiply<float16_t>(MatA, Vec); | ||||||||||||||
|
|
||||||||||||||
| vector<float16_t, 16> NullBias = (vector<float16_t, 16>)0; | ||||||||||||||
| vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(Layer1, MatB, NullBias); | ||||||||||||||
| vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(MatA, Layer1, NullBias); | ||||||||||||||
|
|
||||||||||||||
| VectorRef<ComponentType::F8_E4M3, 16> MemBias = {MBuf, | ||||||||||||||
| /*start offset*/ 4096}; | ||||||||||||||
| vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(Layer2, MatB, MemBias); | ||||||||||||||
| /*start offset*/ 4096}; | ||||||||||||||
| vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(MatA, Layer2, MemBias); | ||||||||||||||
|
|
||||||||||||||
| // Clang doesn't yet support packed types. | ||||||||||||||
| #ifdef __hlsl_dx_compiler | ||||||||||||||
| vector<uint8_t4_packed, 4> SomeData = (vector<uint8_t4_packed, 4>)0; | ||||||||||||||
|
|
||||||||||||||
| vector<float16_t, 16> Layer4 = MultiplyAdd<float16_t>( | ||||||||||||||
| MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MatB, | ||||||||||||||
| MemBias); | ||||||||||||||
| MatA, MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MemBias); | ||||||||||||||
| vector<float16_t, 16> Layer5 = MultiplyAdd<float16_t>( | ||||||||||||||
| MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MatB, | ||||||||||||||
| NullBias); | ||||||||||||||
| MatA, MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), NullBias); | ||||||||||||||
| #endif | ||||||||||||||
| } | ||||||||||||||
| ``` | ||||||||||||||
|
|
@@ -416,7 +411,7 @@ The following table summarizes the operations supported for each matrix scope: | |||||||||||||
| | `Matrix::SumAccumulate()` | ✗ | ✓ | ✓ | | ||||||||||||||
| | `linalg::Multiply(Matrix, Matrix)` | ✗ | ✓ | ✓ | | ||||||||||||||
| | `linalg::Multiply(vector, Matrix)` | ✓ | ✗ | ✗ | | ||||||||||||||
| | `linalg::MultiplyAdd(vector, Matrix, vector)` | ✓ | ✗ | ✗ | | ||||||||||||||
| | `linalg::MultiplyAdd(Matrix, vector, vector)` | ✓ | ✗ | ✗ | | ||||||||||||||
| | `linalg::OuterProduct(vector, vector)` | ✓ | ✓ | ✓ | | ||||||||||||||
|
|
||||||||||||||
| Throughout this document a matrix may be described as having a scope as | ||||||||||||||
|
|
@@ -925,21 +920,21 @@ infers the type of the output accumulator to match the input vector element type | |||||||||||||
| the other overload takes a template parameter for the output matrix element type. | ||||||||||||||
| All matrix scopes are allowed for the output matrix. | ||||||||||||||
|
|
||||||||||||||
| #### linalg::MultiplyAdd(vector, Matrix, vector) | ||||||||||||||
| #### linalg::MultiplyAdd(Matrix, vector, vector) | ||||||||||||||
|
|
||||||||||||||
| ``` c++ | ||||||||||||||
| template <typename OutputElTy, typename InputElTy, typename BiasElTy, uint M, | ||||||||||||||
| uint K, ComponentType MatrixDT> | ||||||||||||||
| vector<OutputElTy, K> | ||||||||||||||
| linalg::MultiplyAdd(vector<InputElTy, M>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>, | ||||||||||||||
| linalg::MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>, | ||||||||||||||
| vector<InputElTy, M>, | ||||||||||||||
| vector<BiasElTy, K>); | ||||||||||||||
| ``` | ||||||||||||||
|
|
||||||||||||||
| Requires `Thread` scope matrix input, may be called from divergent control flow. | ||||||||||||||
|
|
||||||||||||||
| The `linalg::MultiplyAdd` function has an overload that takes an `M`-element vector, an | ||||||||||||||
| MxK `B` matrix with `Thread` scope, and a `K`-element vector. The operation | ||||||||||||||
| MxK `A` matrix with `Thread` scope, and a `K`-element vector. The operation | ||||||||||||||
| multiplies the `M`-element vector by the matrix then adds the `K`-element vector | ||||||||||||||
| producing a result `K`-element vector. | ||||||||||||||
|
Comment on lines
+937
to
939
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
|
|
@@ -1209,37 +1204,37 @@ Must be called from wave-uniform control flow. | |||||||||||||
| ``` llvm | ||||||||||||||
| declare <[NUMo] x [TYo]> @dx.op.matvecmul.v[NUMo][TYo].v[NUMi][TYi]( | ||||||||||||||
| immarg i32, ; opcode | ||||||||||||||
| %dx.types.MatrixRef, ; matrix A | ||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FWIW: in 0029 the op was vector followed by matrix.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Am I missing something here? The description says we're aligning with the matrix-vector APIs. But as Damyan pointed out that spec has the vector followed by a matrix. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The observation was about the mathematical operation that is being described. It's the difference between: vec = mat * vec //SM6.9 Looking at the 6.9 spec, there are a few cases it's inconsistent as well. The 6.9 spec does describe it as "Matrix-Vector Multiply", (and in fact even here the name is "matvecmul" and the matrix argument is called "matrix A".) SM6.9 says "The @dx.op.matvecmul operation multiplies a MxK dimension matrix and a K sized input vector."
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, and that's probably why I had the HLSL API that way, but I've changed it here intentionally so the HLSL API and DXIL match ordering.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've updated the description to clarify that the HLSL APIs are updated to align with the SM 6.9 HLSL API, and the DXIL is updated to align with the HLSL APIs where SM 6.9 had different argument ordering between HLSL and DXIL. |
||||||||||||||
| <[NUMi] x [TYi]>, ; input vector | ||||||||||||||
| immarg i32, ; input interpretation type (DXILComponentType) | ||||||||||||||
| %dx.types.MatrixRef ; matrix A | ||||||||||||||
| immarg i32 ; input interpretation type (DXILComponentType) | ||||||||||||||
| ) | ||||||||||||||
| ``` | ||||||||||||||
|
|
||||||||||||||
| This operation implements a row-vector multiplication against a `B` matrix of | ||||||||||||||
| This operation implements a row-vector multiplication against an `A` matrix of | ||||||||||||||
| `Thread` scope. | ||||||||||||||
|
|
||||||||||||||
| Validation will enforce that: | ||||||||||||||
| * The input vector length matches the `M` matrix dimension | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
| * The matrix A is a `B` matrix of `Thread` scope | ||||||||||||||
| * The matrix A is an `A` matrix of `Thread` scope | ||||||||||||||
|
|
||||||||||||||
| ``` llvm | ||||||||||||||
| declare <[NUMo] x [TYo]> @dx.op.matvecmuladd.v[NUMo][TYo].v[NUMi][TYi].v[NUMo][TYb]( | ||||||||||||||
| immarg i32, ; opcode | ||||||||||||||
| %dx.types.MatrixRef, ; matrix A | ||||||||||||||
| <[NUMi] x [TYi]>, ; input vector | ||||||||||||||
| immarg i32, ; input interpretation type (DXILComponentType) | ||||||||||||||
| %dx.types.MatrixRef, ; matrix A | ||||||||||||||
| <[NUMo] x [TYb]>, ; bias vector | ||||||||||||||
| immarg i32 ; bias interpretation type (DXILComponentType) | ||||||||||||||
| ) | ||||||||||||||
| ``` | ||||||||||||||
|
|
||||||||||||||
| This operation implements a row-vector multiplication against a `B` matrix of | ||||||||||||||
| This operation implements a row-vector multiplication against an `A` matrix of | ||||||||||||||
| `Thread` scope with a bias vector added to the result. | ||||||||||||||
|
|
||||||||||||||
| Validation will enforce that: | ||||||||||||||
| * The input vector length matches the `M` matrix dimension | ||||||||||||||
| * The bias vector length matches the `K` matrix dimension | ||||||||||||||
|
Comment on lines
1235
to
1236
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
| * The matrix A is a `B` matrix of `Thread` scope | ||||||||||||||
| * The matrix A is an `A` matrix of `Thread` scope | ||||||||||||||
|
|
||||||||||||||
| ```llvm | ||||||||||||||
| declare void @dx.op.matrixAccumulateToDescriptor( | ||||||||||||||
|
|
@@ -1371,7 +1366,7 @@ in the [`DXILComponentType` enumeration](#dxil-enumerations). | |||||||||||||
|
|
||||||||||||||
| ## Appendix 2: HLSL Header | ||||||||||||||
|
|
||||||||||||||
| [Compiler Explorer](https://godbolt.org/z/W5a7zbPr3) | ||||||||||||||
| [Compiler Explorer](https://godbolt.org/z/zfK5WKoYP) | ||||||||||||||
| > Note: this mostly works with Clang, but has some issues to work out still. | ||||||||||||||
|
|
||||||||||||||
| ```cpp | ||||||||||||||
|
|
@@ -1636,41 +1631,39 @@ Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::ThreadGroup>, | |||||||||||||
|
|
||||||||||||||
| template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K, | ||||||||||||||
| ComponentEnum MatrixDT, MatrixScopeEnum Scope> | ||||||||||||||
| vector<OutputElTy, K> Multiply(vector<InputElTy, M>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, Scope>); | ||||||||||||||
| vector<OutputElTy, K> Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>, | ||||||||||||||
| vector<InputElTy, M>); | ||||||||||||||
|
Comment on lines
+1634
to
+1635
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| template <typename OutputElTy, typename InputElTy, typename BiasElTy, | ||||||||||||||
| SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT, | ||||||||||||||
| MatrixScopeEnum Scope> | ||||||||||||||
| vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, Scope>, | ||||||||||||||
| vector<BiasElTy, K>); | ||||||||||||||
| vector<OutputElTy, K> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>, | ||||||||||||||
| vector<InputElTy, M>, vector<BiasElTy, K>); | ||||||||||||||
|
Comment on lines
+1640
to
+1641
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp, | ||||||||||||||
| typename BiasElTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K, | ||||||||||||||
| ComponentEnum MatrixDT, MatrixScopeEnum Scope> | ||||||||||||||
| typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size == | ||||||||||||||
| M, | ||||||||||||||
| vector<OutputElTy, K> >::type | ||||||||||||||
| MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, Scope>, | ||||||||||||||
| MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>, | ||||||||||||||
| InterpretedVector<InputElTy, N, InputInterp>, | ||||||||||||||
| vector<BiasElTy, K>); | ||||||||||||||
|
|
||||||||||||||
| template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy, | ||||||||||||||
| SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT> | ||||||||||||||
| vector<OutputElTy, K> | ||||||||||||||
| MultiplyAdd(vector<InputElTy, M>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>, | ||||||||||||||
| VectorRef<BiasElTy, K>); | ||||||||||||||
| MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>, | ||||||||||||||
| vector<InputElTy, M>, VectorRef<BiasElTy, K>); | ||||||||||||||
|
|
||||||||||||||
| template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp, | ||||||||||||||
| ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K, | ||||||||||||||
| ComponentEnum MatrixDT> | ||||||||||||||
| typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size == | ||||||||||||||
| M, | ||||||||||||||
| vector<OutputElTy, K> >::type | ||||||||||||||
| MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp>, | ||||||||||||||
| Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>, | ||||||||||||||
| MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>, | ||||||||||||||
| InterpretedVector<InputElTy, N, InputInterp>, | ||||||||||||||
| VectorRef<BiasElTy, K>); | ||||||||||||||
|
|
||||||||||||||
| // Outer product functions | ||||||||||||||
|
|
@@ -1719,30 +1712,30 @@ ByteAddressBuffer MBuf : register(t0); | |||||||||||||
|
|
||||||||||||||
| void CoopVec() { | ||||||||||||||
| using namespace dx::linalg; | ||||||||||||||
| using MatrixBTy = | ||||||||||||||
| Matrix<ComponentType::F16, 16, 16, MatrixUse::B, MatrixScope::Thread>; | ||||||||||||||
| using MatrixATy = | ||||||||||||||
| Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Thread>; | ||||||||||||||
|
|
||||||||||||||
| vector<float16_t, 16> Vec = (vector<float16_t, 16>)0; | ||||||||||||||
| MatrixBTy MatB = MatrixBTy::Load( | ||||||||||||||
| MatrixATy MatA = MatrixATy::Load( | ||||||||||||||
| MBuf, 0, /* Row stride = number of columns * element size */ 16 * 4, | ||||||||||||||
| MatrixLayout::RowMajor); | ||||||||||||||
| vector<float16_t, 16> Layer1 = Multiply<float16_t>(Vec, MatB); | ||||||||||||||
| vector<float16_t, 16> Layer1 = Multiply<float16_t>(MatA, Vec); | ||||||||||||||
|
|
||||||||||||||
| vector<float16_t, 16> NullBias = (vector<float16_t, 16>)0; | ||||||||||||||
| vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(Layer1, MatB, NullBias); | ||||||||||||||
| vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(MatA, Layer1, NullBias); | ||||||||||||||
|
|
||||||||||||||
| VectorRef<ComponentType::F8_E4M3, 16> MemBias = {MBuf, | ||||||||||||||
| /*start offset*/ 4096}; | ||||||||||||||
| vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(Layer2, MatB, MemBias); | ||||||||||||||
| vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(MatA, Layer2, MemBias); | ||||||||||||||
|
|
||||||||||||||
| // Clang doesn't yet support packed types. | ||||||||||||||
| #ifdef __hlsl_dx_compiler | ||||||||||||||
| vector<uint8_t4_packed, 4> SomeData = (vector<uint8_t4_packed, 4>)0; | ||||||||||||||
|
|
||||||||||||||
| vector<float16_t, 16> Layer4 = MultiplyAdd<float16_t>( | ||||||||||||||
| MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MatB, MemBias); | ||||||||||||||
| MatA, MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MemBias); | ||||||||||||||
| vector<float16_t, 16> Layer5 = MultiplyAdd<float16_t>( | ||||||||||||||
| MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MatB, NullBias); | ||||||||||||||
| MatA, MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), NullBias); | ||||||||||||||
| #endif | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.