
Commit 037cad5

doc: code gen batched gemm
1 parent 530fd92 commit 037cad5


docs_sphinx/submissions/report_25_05_15.rst

Lines changed: 77 additions & 12 deletions
@@ -1,8 +1,8 @@
Submission 2025-05-15
=====================

-Batch-Reduce GEMM
------------------
+Neon Batch-Reduce GEMM
+----------------------

This section considers a batch-reduce matrix-matrix multiplication that has a fourth dimension in addition to the known M, N, and K dimensions.

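For illustration (this is not part of the diff above): the batch-reduce operation described in this section computes C += A_b * B_b summed over the batch index b. A minimal naive C++ reference sketch for column-major fp32 data follows; the function name, signature, and batch strides are assumptions made for this example, not the project's API.

.. code-block:: cpp

    #include <cstdint>

    // Naive reference for the batch-reduce GEMM C += sum_b A_b * B_b.
    // Column-major fp32; lda, ldb, ldc are leading dimensions, and
    // br_stride_a / br_stride_b are the element offsets between consecutive
    // batch matrices (e.g. lda * k and ldb * n for densely packed batches).
    void batch_reduce_gemm_ref(const float *a, const float *b, float *c,
                               int64_t m, int64_t n, int64_t k, int64_t br_size,
                               int64_t lda, int64_t ldb, int64_t ldc,
                               int64_t br_stride_a, int64_t br_stride_b)
    {
        for (int64_t br = 0; br < br_size; ++br)
            for (int64_t j = 0; j < n; ++j)
                for (int64_t p = 0; p < k; ++p)
                {
                    const float b_pj = b[br * br_stride_b + j * ldb + p];
                    for (int64_t i = 0; i < m; ++i)
                        c[j * ldc + i] += a[br * br_stride_a + p * lda + i] * b_pj;
                }
    }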

@@ -13,7 +13,7 @@ File: ``neon_6_1.s``

We started by implementing a kernel ``matmul_64_48_64`` with a batch dimension of one, which is in the file ``neon_6_1_batch1.s``.

-.. code-block::asm
+.. code-block:: asm
    :linenos:
    :emphasize-lines: 18

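For illustration (not part of the diff): the kernel body itself is not shown in this hunk. The kernel names used later in the report (e.g. ``matmul_16m_4n_k``) suggest a 16x4 micro-tile, and the following hypothetical C++/NEON sketch shows what one such tile update could look like; the function name, signature, and layout assumptions are invented here, and the hand-written assembly may be organised differently.

.. code-block:: cpp

    #include <arm_neon.h>
    #include <cstdint>

    // Hypothetical 16x4 fp32 micro-tile: 16 rows of C (four float32x4_t
    // accumulators per column) times 4 columns, with a scalar loop over K.
    // Column-major: A is 16 x k with leading dimension lda, B is k x 4 with
    // leading dimension ldb, C is 16 x 4 with leading dimension ldc.
    void micro_kernel_16x4(const float *a, const float *b, float *c,
                           int64_t k, int64_t lda, int64_t ldb, int64_t ldc)
    {
        float32x4_t acc[4][4]; // acc[j][i] holds rows 4*i .. 4*i+3 of column j
        for (int j = 0; j < 4; ++j)
            for (int i = 0; i < 4; ++i)
                acc[j][i] = vld1q_f32(c + j * ldc + 4 * i);

        for (int64_t p = 0; p < k; ++p)
        {
            float32x4_t a_col[4];
            for (int i = 0; i < 4; ++i)
                a_col[i] = vld1q_f32(a + p * lda + 4 * i); // 16 rows of A(:, p)
            for (int j = 0; j < 4; ++j)
            {
                const float32x4_t b_pj = vdupq_n_f32(b[j * ldb + p]); // B(p, j)
                for (int i = 0; i < 4; ++i)
                    acc[j][i] = vfmaq_f32(acc[j][i], a_col[i], b_pj);
            }
        }

        for (int j = 0; j < 4; ++j)
            for (int i = 0; i < 4; ++i)
                vst1q_f32(c + j * ldc + 4 * i, acc[j][i]);
    }

A 64_48_64 kernel can then be formed by stepping such a tile four times along M and twelve times along N.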
@@ -51,10 +51,10 @@ We started by implementing a kernel ``matmul_64_48_64`` with a batch dimension o

Then we wrapped the ``matmul_64_48_64`` kernel inside another batch loop of size 16:

-.. code-block::asm
+.. code-block:: asm
    :linenos:
    :emphasize-lines: 3, 41
-
+
    ...
    mov x19, #16 // x19 iterator for the batch dimension
    matmul_loop_batch_dimension:
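For illustration (not part of the diff): the highlighted label is the head of the new batch loop. Conceptually, the wrapper runs the fixed-size 64_48_64 kernel 16 times, advancing the A and B pointers to the next batch matrix while accumulating into the same C. The hypothetical C++ outline below assumes densely packed batches; the names and strides are illustrative only, not the kernel's actual calling convention.

.. code-block:: cpp

    #include <cstdint>

    // Assumed to accumulate C += A * B for a single 64x48x64 batch element.
    extern void matmul_64_48_64(const float *a, const float *b, float *c,
                                int64_t lda, int64_t ldb, int64_t ldc);

    // Hypothetical outline of the batch wrapper (x19 plays the role of br).
    void br_matmul_64_48_64_16(const float *a, const float *b, float *c,
                               int64_t lda, int64_t ldb, int64_t ldc)
    {
        constexpr int64_t n = 48, k = 64, br_size = 16;
        for (int64_t br = 0; br < br_size; ++br)
        {
            matmul_64_48_64(a, b, c, lda, ldb, ldc); // C += A_br * B_br
            a += lda * k;                            // next A batch matrix
            b += ldb * n;                            // next B batch matrix
        }
    }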
@@ -134,11 +134,11 @@ GEMM
1. Extend generate to support M-N-K combinations for column-major format :math:`1 \leq M,N \leq 1024, 1 \leq K \leq 2028`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-To support all M-N-K combinations we take a kernel as base and dynamically generate the rest handling of not multiple of M, N and K.
-As a base we took the ``matmul_16m_4n_k`` kernel, which reached around ``130 GFLOPS`` as 64_48_64 kernel (i.e. the same as kernel from the previous
-section with the batch dimension of one).
+To support all combinations of M, N and K, we use one kernel as a base and dynamically generate the rest of the handling for numbers that are not multiples of M, N or K.
+As a base we took the ``matmul_16m_4n_k`` kernel, which reached around ``130 GFLOPS`` as a 64_48_64 kernel (i.e. the same as the kernel from the
+previous section, with a batch dimension of one).
The k dimension is always a multiple of 1; therefore, we don't need a special case for this dimension.
-To get full coverage on the remaining dimension, we implemented the variations:
+To get full coverage on the remaining dimensions, we implemented the following variations:

- `matmul_16m_lt4nRest_k`:
  - M dimension must be multiple of 16
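For illustration (not part of the diff): the hunk above is cut off in the middle of the variation list; the remaining variants follow in the file itself. A dispatcher over the six kernels could pick a variant from the remainders of M and N as sketched below; the selection conditions are inferred from the kernel names and are assumptions, not the project's actual generator logic.

.. code-block:: cpp

    #include <cstdint>

    // Hypothetical selection of a kernel variant from the M and N remainders.
    // The names mirror the kernels listed in the report; the real generator
    // emits the corresponding assembly rather than calling C++ functions.
    enum class kernel_variant
    {
        matmul_16m_4n_k,           // M % 16 == 0, N % 4 == 0
        matmul_16m_lt4nRest_k,     // M % 16 == 0, N % 4 != 0
        matmul_16mRest_4n_k,       // M > 16 with remainder, N % 4 == 0
        matmul_16mRest_lt4nRest_k, // M > 16 with remainder, N % 4 != 0
        matmul_lt16_4n_k,          // M < 16, N % 4 == 0
        matmul_lt16_lt4nRest_k     // M < 16, N % 4 != 0
    };

    kernel_variant select_kernel(int64_t m, int64_t n)
    {
        const bool n_rest = (n % 4) != 0;
        if (m < 16)
            return n_rest ? kernel_variant::matmul_lt16_lt4nRest_k
                          : kernel_variant::matmul_lt16_4n_k;
        if (m % 16 != 0)
            return n_rest ? kernel_variant::matmul_16mRest_lt4nRest_k
                          : kernel_variant::matmul_16mRest_4n_k;
        return n_rest ? kernel_variant::matmul_16m_lt4nRest_k
                      : kernel_variant::matmul_16m_4n_k;
    }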
@@ -173,14 +173,79 @@ Together with the `matmul_16m_4n_k`, we have 6 kernels to cover the complete dim
2. Verify all matrices for ``1≤M≤64``, ``1≤N≤64``, ``K∈[1,16,32,64,128]``,``lda=M``, ``ldb=K``, and ``ldc=M``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-All GEMM generation and execution with these configuration work with counting upwards and random data.
+All GEMM generation and execution using this configuration works with counting upwards and random data.

3. Verify all matrices for ``1≤M≤64``, ``1≤N≤64``, ``K∈[1,16,32,64,128]``,``lda>M``, ``ldb>K``, and ``ldc>M``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-All GEMM generation and execution with these configuration work with counting upwards and random data.
+All GEMM generation and execution using this configuration works with counting upwards and random data.

4. Benchmark for ``1≤M≤64``, ``1≤N≤64``, ``K∈[1,16,32,64,128]``,``lda=M``, ``ldb=K``, and ``ldc=M``.
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Running the Benchmark in approximately 8 hours total. We produced the following results: :download:`GEMM_benchmarks.csv <../_static/resources/report_25_05_15/GEMM_benchmarks.csv>`
+The benchmark took approximately eight hours in total to run. The following results were produced: :download:`GEMM_benchmarks.csv <../_static/resources/report_25_05_15/GEMM_benchmarks.csv>`
+
+
+Batch-Reduce GEMM
+-----------------
+
+1. Extend generate to support batch dimension 1≤batch_size≤1024
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to support an additional batch dimension in our implemented kernels, we placed all kernels within an additional batch loop.
+Consequently, the logic in our ``Brgemm.cpp`` was extended to check whether the batch dimension is greater than one.
+
+.. code-block:: cpp
+    :linenos:
+    :emphasize-lines: 19
+
+    ...
+    if (dtype != dtype_t::fp32)
+    {
+        return error_t::err_wrong_dtype;
+    }
+    if (m == 0 || n == 0 || k == 0)
+    {
+        return error_t::err_wrong_dimension;
+    }
+    if ((trans_a + trans_b + trans_c) != 0)
+    {
+        return error_t::err_row_major_order_not_supported;
+    }
+
+    if (br_size == 1 && (trans_a + trans_b + trans_c) == 0 && dtype == dtype_t::fp32)
+    {
+        fill_with_matmuls_no_batch_dim_column_major_fp32(m, n, k);
+    }
+    else if (br_size > 1 && (trans_a + trans_b + trans_c) == 0 && dtype == dtype_t::fp32)
+    {
+        fill_with_matmuls_batch_dim_column_major_fp32(m, n, k, br_size);
+    }
+    else
+    {
+        throw std::logic_error(
+            std::format("Unhandled parameter combination found: m='{}', n='{}', k='{}', br_size='{}', trans_a='{}', trans_b='{}', "
+                        "trans_c = '{}', dtype = '{}'",
+                        m, n, k, br_size, trans_a, trans_b, trans_c, static_cast<int32_t>(dtype)));
+    }
+    ...
+
+This ``else if`` branch dispatches to our extended ``br_matmul_*`` kernels with a larger batch dimension:
+
+- `br_matmul_16m_lt4nRest_k`
+- `br_matmul_16mRest_4n_k`
+- `br_matmul_16mRest_lt4nRest_k`
+- `br_matmul_lt16_4n_k`
+- `br_matmul_lt16_lt4nRest_k`
+
+2. Verify against reference implementation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+All kernels were tested. The tests are located in the file ``src/test/kernels/br_matmul_*.test.cpp``.
+
+The batched MatMul generation was tested for 1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128], 1≤BatchSize≤16, lda=M, ldb=K, and ldc=M. The test is located in the file ``src/test/Brgemm.test.cpp``.
+
+3. Benchmark for 1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128], lda=M, ldb=K, ldc=M, batch_size=16
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The benchmark took approximately eight hours in total to run. The following results were produced:
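For illustration (not part of the diff): throughput figures such as the ``130 GFLOPS`` quoted earlier follow from the kernel's operation count. The sketch below shows how such a figure can be computed for the batched case, counting two floating-point operations per multiply-add; the function is an assumption for this example, not the project's benchmark harness.

.. code-block:: cpp

    #include <cstdint>

    // GFLOPS for a batch-reduce GEMM: each of the br_size matrix products
    // performs m*n*k fused multiply-adds, i.e. 2*m*n*k floating-point ops.
    double gflops(int64_t m, int64_t n, int64_t k, int64_t br_size,
                  int64_t repetitions, double elapsed_seconds)
    {
        const double flops = 2.0 * static_cast<double>(m) * static_cast<double>(n) *
                             static_cast<double>(k) * static_cast<double>(br_size) *
                             static_cast<double>(repetitions);
        return flops / (elapsed_seconds * 1.0e9);
    }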
