hasten neural network operations. To utilize Tensor Cores via cuBLAS when performing GEMM, we can use the function `cublasGemmEx`; its signature is shown in Code `lst:cublasGemmEx`.

**lst:cublasGemmEx**
```cuda
cublasStatus_t cublasGemmEx(cublasHandle_t handle,
                            cublasOperation_t transa, cublasOperation_t transb,
                            int m, int n, int k,
                            const void *alpha,
                            const void *A, cudaDataType_t Atype, int lda,
                            const void *B, cudaDataType_t Btype, int ldb,
                            const void *beta,
                            void *C, cudaDataType_t Ctype, int ldc,
                            cublasComputeType_t computeType,
                            cublasGemmAlgo_t algo)
```

`handle` is the cuBLAS handle, which is created using the `cublasCreate` function. `transa` denotes whether matrix $\bf{A}$ is transposed, while `transb` denotes whether matrix $\bf{B}$ is transposed. `m`, `n`, and `k` describe the shape of the matrices. `alpha` and `beta` scale the product $\bf{A}\bf{B}$ and the existing values of $\bf{C}$, respectively. `A`, `B`, and `C` are pointers to the starting addresses of the matrices. `Atype`, `Btype`, and `Ctype` describe the data types of the matrices; for example, `CUDA_R_16F` indicates that the data is stored as real 16-bit floating-point values. `lda`, `ldb`, and `ldc` are the leading dimensions of the matrices. `computeType` is the data type used in computation; for instance, `CUBLAS_COMPUTE_16F` implies that the computation is performed in 16-bit floating point using Tensor Cores. Notably, if the input data type is 32-bit floating point, we can use `CUBLAS_COMPUTE_32F_FAST_16F` to perform the computation in 16-bit floating point and achieve acceleration with Tensor Cores. `algo` selects the algorithm used in computation, and `CUBLAS_GEMM_DEFAULT` is commonly used to select the default algorithm.

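The following is a minimal sketch of how `cublasGemmEx` might be invoked for a half-precision GEMM with FP32 accumulation. The helper function name, device pointers, and matrix sizes are illustrative assumptions, and error checking is omitted for brevity.

```cuda
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Hypothetical helper: computes C = alpha * A * B + beta * C, where A and B are
// FP16 device matrices, C is an FP32 device matrix, and all matrices are stored
// in column-major order without transposition.
void gemm_fp16_fp32(const half *d_A, const half *d_B, float *d_C,
                    int m, int n, int k) {
    cublasHandle_t handle;
    cublasCreate(&handle);

    float alpha = 1.0f, beta = 0.0f;
    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                 m, n, k,
                 &alpha,
                 d_A, CUDA_R_16F, m,   // A is m x k, leading dimension m
                 d_B, CUDA_R_16F, k,   // B is k x n, leading dimension k
                 &beta,
                 d_C, CUDA_R_32F, m,   // C is m x n, leading dimension m
                 CUBLAS_COMPUTE_32F,   // accumulate in FP32; Tensor Cores are used when eligible
                 CUBLAS_GEMM_DEFAULT);

    cublasDestroy(handle);
}
```

With FP32 inputs, swapping `CUBLAS_COMPUTE_32F` for `CUBLAS_COMPUTE_32F_FAST_16F` allows the library to down-convert internally and use Tensor Cores, as described above.
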
### Primitives for Hardware Units

The second approach to accelerator programming involves the use of programming primitives, such as invoking the CUDA Warp Matrix Multiply-Accumulate (WMMA) API on a device. This approach hinges on the collaborative design of software and hardware, meaning that the design of programming APIs at this level is architecture-dependent. For instance, in the Volta architecture, the control object of WMMA is a $16\times16$ matrix block, processed by two Tensor Cores at a time. This notion is tightly linked to the way Tensor Cores are integrated into an SM.

In the Volta architecture, NVIDIA offers three distinct sizes of WMMA multiply-accumulate computing interfaces for FP16 input data: $16\times16\times16$, $32\times8\times16$, and $8\times32\times16$.

The basic control unit of the WMMA API is a fragment, which refers to a template class that specifies information such as the role of a matrix (multiplier or accumulator), the matrix shape (`WMMA_M`, `WMMA_N`, and `WMMA_K`), the data type (FP16, FP32, etc.), and the layout (`row_major` or `col_major`). Code `lst:fragment` shows the fragment types.

**lst:fragment**
```cuda
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
```

The data of the matrix blocks required by the multiplication needs to be loaded into registers as fragments. Fragments are initialized or cleared before the computation; after the multiply-accumulate operations are performed by Tensor Cores, the fragments are stored back to global memory. NVIDIA provides the `wmma::load_matrix_sync()` and `wmma::store_matrix_sync()` interfaces to load and store the submatrix blocks. The `wmma::fill_fragment()` interface is used to initialize the data of a fragment, and the `wmma::mma_sync()` interface is used to perform multiply-accumulate operations on fragments.

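Putting these interfaces together, the following is a minimal sketch of a warp-level kernel that multiplies a single $16\times16$ FP16 tile pair and accumulates in FP32. The kernel name, the assumption that the inputs are exactly one WMMA tile, and the launch configuration are illustrative only.

```cuda
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

constexpr int WMMA_M = 16, WMMA_N = 16, WMMA_K = 16;

// One warp computes a single 16x16 output tile: C = A * B.
// A is a row-major 16x16 half tile, B is a col-major 16x16 half tile,
// and C is a row-major 16x16 float tile.
__global__ void wmma_tile_gemm(const half *a, const half *b, float *c) {
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;

    wmma::fill_fragment(acc_frag, 0.0f);                  // clear the accumulator
    wmma::load_matrix_sync(a_frag, a, WMMA_K);            // load the A tile (leading dimension K)
    wmma::load_matrix_sync(b_frag, b, WMMA_K);            // load the B tile (leading dimension K)
    wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);   // acc = A * B + acc on Tensor Cores
    wmma::store_matrix_sync(c, acc_frag, WMMA_N, wmma::mem_row_major);  // write the result tile
}
```

The kernel would be launched with a single warp, e.g. `wmma_tile_gemm<<<1, 32>>>(d_a, d_b, d_c);`, and larger matrices would be tiled across warps and thread blocks in the same fashion.
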
### Low-level Assembly Language Interface

The PTX ISA offers another programming interface, for example, the `mma.sync.aligned.m8n8k4` instruction in the Volta architecture. This instruction uses the shape configuration $M=8$, $N=8$, $K=4$ to perform multiply-accumulate operations. The basic control unit of this API is the data element. The matrix shape (modifier `.m8n8k4`), the matrix layouts (modifiers `.row` or `.col`), and the data types of the output matrix D, the input matrices A and B, and the input accumulator C (modifiers `.f32` or `.f16`) need to be specified. NVIDIA's documentation[^1] provides information about using the PTX instruction set in inline assembly, helping programmers write code that follows the corresponding syntax rules, as shown in Code `lst:ptx`.

**lst:ptx**
```cpp
#include <cuda_fp16.h>

// Device-side snippet: one m8n8k4 FP16 multiply-accumulate with FP32 accumulators.
// a and b point to the FP16 input elements held by this thread; C and D point to
// the FP32 input and output accumulator elements (eight per thread).
half *a, *b;
float *C, *D;
// Two FP16 elements are packed into each 32-bit register.
unsigned const *A = reinterpret_cast<unsigned const *>(a);
unsigned const *B = reinterpret_cast<unsigned const *>(b);

asm volatile(
    "mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 "
    "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
    "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
    : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]), "=f"(D[4]),
      "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
    : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "f"(C[0]),
      "f"(C[1]), "f"(C[2]), "f"(C[3]), "f"(C[4]), "f"(C[5]),
      "f"(C[6]), "f"(C[7]));
```

Data elements are used directly as the input (the `unsigned` type is used to hold the packed FP16 data elements). Moreover, NVIDIA provides the `ldmatrix` instruction to load data from shared memory into fragments.

The finer-grained `mma` instruction can be composed into warp-level WMMA operations of more diverse shapes and gives control over the mapping between threads and data within a warp. The PTX instructions therefore offer greater flexibility than directly using CUDA C++ code.
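
As an illustration, a `ldmatrix` load might be wrapped in a small device helper like the sketch below. The helper name is hypothetical, the instruction requires a Turing (sm_75) or newer GPU, and each thread is assumed to pass the shared-memory address of one 8-element row of the tiles being loaded.

```cuda
// Hypothetical helper: loads four 8x8 FP16 matrices from shared memory into
// four 32-bit registers per thread using the ldmatrix PTX instruction
// (requires sm_75 or newer). smem_row points to the 8-element row of the
// tile assigned to the calling thread.
__device__ void ldmatrix_x4(const void *smem_row,
                            unsigned &r0, unsigned &r1,
                            unsigned &r2, unsigned &r3) {
    // Convert the generic pointer to a 32-bit shared-memory address.
    unsigned addr = static_cast<unsigned>(__cvta_generic_to_shared(smem_row));
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
        : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
        : "r"(addr));
}
```

The four returned registers can then be supplied as the `"r"` operands of an `mma` instruction of a matching shape, in the same style as Code `lst:ptx`.
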

[^1]: Available at <https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html>