Commit 475ca2d

[SYCL][Docs][Joint matrix] Add overloads and restrictions for the offset load store (#15499)
- Add missing restriction on the stride of the checked variants of load/store
- Add new overloads of `joint_matrix_load` and `joint_matrix_store` where the offsets are separated from the base pointer and added as separate arguments. I kept the same name as the expectation is to remove the regular variants once the new ones are used instead.
- Add restrictions on both the regular and the offset `joint_matrix_load/store` on PVC since in the current implementation, no runtime checks are added as they are expensive. The fall back to 1d load/store is done using a flag instead.
1 parent f26139e commit 475ca2d

2 files changed

+181
-44
lines changed


sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc

Lines changed: 77 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -171,6 +171,37 @@ void joint_matrix_store(Group g,
171171
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dest,
172172
size_t stride);
173173

174+
// Overloads for offset store
175+
template <typename Group, typename T, size_t Rows, size_t Cols,
176+
layout Layout, access::address_space Space,
177+
access::decorated IsDecorated>
178+
void joint_matrix_store(Group g,
179+
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
180+
multi_ptr<T, Space, IsDecorated> base_dest, size_t row_index,
181+
size_t col_index, size_t stride);
182+
183+
template <typename Group, typename T, size_t Rows, size_t Cols,
184+
layout Layout, access::address_space Space,
185+
access::decorated IsDecorated>
186+
void joint_matrix_store(Group g,
187+
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
188+
multi_ptr<T, Space, IsDecorated> base_dest, size_t row_index,
189+
size_t col_index, size_t stride);
190+
191+
template <typename Group, typename T, size_t Rows, size_t Cols,
192+
layout Layout, typename PropertyListT>
193+
void joint_matrix_store(Group g,
194+
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
195+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT>
196+
base_dest, size_t row_index, size_t col_index, size_t stride);
197+
198+
template <typename Group, typename T, size_t Rows, size_t Cols,
199+
layout Layout, typename PropertyListT>
200+
void joint_matrix_store(Group g,
201+
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
202+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT>
203+
base_dest, size_t row_index, size_t col_index, size_t stride);
204+
174205
} // namespace sycl::ext::intel::experimental::matrix
175206
```
176207

@@ -244,19 +275,19 @@ supporting the out of bounds checked APIs that are defined in this section.
244275
In this section, we refer to the memory buffer where a `joint_matrix`
245276
is loaded from or stored to as the global matrix. This global matrix
246277
is also interpreted as a two-dimensional memory region as follows, where
247-
`GlobalRows` is number of rows in the global matrix, `GlobalCols` is number of
248-
columns in the global matrix, `Stride` is number of columns that include
278+
`global_rows` is number of rows in the global matrix, `global_cols` is number of
279+
columns in the global matrix, `stride` is number of columns that include
249280
the out of bounds data (depicted as x here).
250281

251282
```
252-
GlobalCols
283+
global_cols
253284
<----------->
254285
dddddddddddddxxx ^
255-
dddddddddddddxxx | GlobalRows
286+
dddddddddddddxxx | global_rows
256287
dddddddddddddxxx v
257288
xxxxxxxxxxxxxxxx
258289
<-------------->
259-
Stride
290+
stride
260291
```
261292

262293
In the diagram above, the global matrix has 13 columns and 3
@@ -293,15 +324,15 @@ checking, namely `joint_matrix_fill`, `joint_matrix_load`, and
293324
the global memory matrix, which is different from the APIs that do not
294325
do bounds checking. Those non-bounds-checking APIs take a pointer to
295326
the base of the joint matrix.
296-
* The coordinates `RowIndex` and `ColIndex` into the global matrix to
327+
* The coordinates `row_index` and `col_index` into the global matrix to
297328
calculate the pointer offset to load/store are given as separate
298329
arguments.
299330
* These variants take extra arguments to determine the global bounds
300-
`GlobalRows` and `GlobalCols` of the global matrix.
331+
`global_rows` and `global_cols` of the global matrix.
301332

302333
To illustrate the out-of-bounds checking, consider the global matrix
303-
shown above which has 13 columns and 3 rows (`GlobalRows=3` and
304-
`GlobalCols=13`), where the joint matrix size is 8 columns by 2 rows defined as
334+
shown above which has 13 columns and 3 rows (`global_rows=3` and
335+
`global_cols=13`), where the joint matrix size is 8 columns by 2 rows defined as
305336
```
306337
joint_matrix<sub_group, bfloat16, use::b, 2, 8, layout::row_major> sub_b;
307338
```
@@ -311,14 +342,14 @@ both dimensions. This is shown below, where capital letters correspond
311342
to the elements that are accessed by this joint matrix load:
312343

313344
```
314-
GlobalCols
345+
global_cols
315346
<----------->
316347
dddddddddddddxxx ^
317-
dddddddddddddxxx | GlobalRows
348+
dddddddddddddxxx | global_rows
318349
ddddddddDDDDDXXX v
319350
xxxxxxxxXXXXXXXX
320351
<-------------->
321-
Stride
352+
stride
322353
```
323354

324355
If the joint matrix is loaded via `joint_matrix_load_checked` using
@@ -335,18 +366,18 @@ namespace sycl::ext::intel::experimental::matrix {
335366
template <typename Group, typename T, size_t Rows, size_t Cols,
336367
use Use, layout Layout, typename Tv>
337368
void joint_matrix_fill_checked(Group g, joint_matrix<Group, T, Use, Rows,
338-
Cols, Layout> &m, Tv v, size_t GlobalRows, size_t GlobalCols,
339-
size_t RowIndex, size_t ColIndex);
369+
Cols, Layout> &m, Tv v, size_t global_rows, size_t global_cols,
370+
size_t row_index, size_t col_index);
340371

341372
// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
342373
template <typename Group, typename T1, typename T2,
343374
size_t Rows, size_t Cols,
344375
access::address_space Space, access::decorated IsDecorated>
345376
void joint_matrix_load_checked(Group g,
346377
joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
347-
multi_ptr<T2, Space, IsDecorated> base_src, size_t Stride,
348-
layout Layout, size_t GlobalRows, size_t GlobalCols,
349-
size_t RowIndex, size_t ColIndex);
378+
multi_ptr<T2, Space, IsDecorated> base_src, size_t stride,
379+
layout Layout, size_t global_rows, size_t global_cols,
380+
size_t row_index, size_t col_index);
350381

351382
// Only available when Layout != layout::dynamic
352383
// and when std::is_same_v<T1, std::remove_const_t<T2>>
@@ -356,17 +387,17 @@ template <typename Group, typename T1, typename T2,
356387
access::address_space Space, access::decorated IsDecorated>
357388
void joint_matrix_load_checked(Group g,
358389
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
359-
multi_ptr<T2, Space, IsDecorated> base_src, size_t Stride,
360-
size_t GlobalRows, size_t GlobalCols, size_t RowIndex, size_t ColIndex);
390+
multi_ptr<T2, Space, IsDecorated> base_src, size_t stride,
391+
size_t global_rows, size_t global_cols, size_t row_index, size_t col_index);
361392

362393
// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
363394
template <typename Group, typename T1, typename T2,
364395
size_t Rows, size_t Cols, typename PropertyListT>
365396
void joint_matrix_load_checked(Group g,
366397
joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
367398
ext::oneapi::experimental::annotated_ptr<T2, PropertyListT> base_src,
368-
size_t Stride, layout Layout, size_t GlobalRows, size_t GlobalCols,
369-
size_t RowIndex, size_t ColIndex);
399+
size_t stride, layout Layout, size_t global_rows, size_t global_cols,
400+
size_t row_index, size_t col_index);
370401

371402
// Only available when Layout != layout::dynamic
372403
// and when std::is_same_v<T1, std::remove_const_t<T2>>
@@ -375,55 +406,55 @@ template <typename Group, typename T1, typename T2, size_t Rows,
375406
void joint_matrix_load_checked(Group g,
376407
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
377408
ext::oneapi::experimental::annotated_ptr<T2, PropertyListT> base_src,
378-
size_t Stride, size_t GlobalRows, size_t GlobalCols,
379-
size_t RowIndex, size_t ColIndex);
409+
size_t stride, size_t global_rows, size_t global_cols,
410+
size_t row_index, size_t col_index);
380411

381412
template <typename Group, typename T, size_t Rows, size_t Cols,
382413
access::address_space Space, access::decorated IsDecorated>
383414
void joint_matrix_store_checked(Group g,
384415
const joint_matrix<Group, T, use::accumulator, Rows, Cols, layout::dynamic> &res,
385-
multi_ptr<T, Space, IsDecorated> base_dest, size_t Stride, layout Layout,
386-
size_t GlobalRows, size_t GlobalCols, size_t RowIndex, size_t ColIndex);
416+
multi_ptr<T, Space, IsDecorated> base_dest, size_t stride, layout Layout,
417+
size_t global_rows, size_t global_cols, size_t row_index, size_t col_index);
387418

388419
template <typename Group, typename T, size_t Rows, size_t Cols,
389420
layout Layout, access::address_space Space,
390421
access::decorated IsDecorated>
391422
void joint_matrix_store_checked(Group g,
392423
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
393-
multi_ptr<T, Space, IsDecorated> base_dest, size_t Stride,
394-
size_t GlobalRows, size_t GlobalCols, size_t RowIndex, size_t ColIndex);
424+
multi_ptr<T, Space, IsDecorated> base_dest, size_t stride,
425+
size_t global_rows, size_t global_cols, size_t row_index, size_t col_index);
395426

396427
template <typename Group, typename T, size_t Rows, size_t Cols,
397428
layout Layout, access::address_space Space,
398429
access::decorated IsDecorated>
399430
void joint_matrix_store_checked(Group g,
400431
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
401-
multi_ptr<T, Space, IsDecorated> base_dest, size_t Stride,
402-
size_t GlobalRows, size_t GlobalCols, size_t RowIndex, size_t ColIndex);
432+
multi_ptr<T, Space, IsDecorated> base_dest, size_t stride,
433+
size_t global_rows, size_t global_cols, size_t row_index, size_t col_index);
403434

404435
template <typename Group, typename T, size_t Rows, size_t Cols,
405436
typename PropertyListT>
406437
void joint_matrix_store_checked(Group g,
407438
const joint_matrix<Group, T, use::accumulator, Rows, Cols, layout::dynamic> &res,
408439
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> base_dest,
409-
size_t Stride, layout Layout, size_t GlobalRows, size_t GlobalCols,
410-
size_t RowIndex, size_t ColIndex);
440+
size_t stride, layout Layout, size_t global_rows, size_t global_cols,
441+
size_t row_index, size_t col_index);
411442

412443
template <typename Group, typename T, size_t Rows, size_t Cols,
413444
layout Layout, typename PropertyListT>
414445
void joint_matrix_store_checked(Group g,
415446
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
416447
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> base_dest,
417-
size_t Stride, size_t GlobalRows, size_t GlobalCols,
418-
size_t RowIndex, size_t ColIndex);
448+
size_t stride, size_t global_rows, size_t global_cols,
449+
size_t row_index, size_t col_index);
419450

420451
template <typename Group, typename T, size_t Rows, size_t Cols,
421452
layout Layout, typename PropertyListT>
422453
void joint_matrix_store_checked(Group g,
423454
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
424455
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> base_dest,
425-
size_t Stride, size_t GlobalRows, size_t GlobalCols,
426-
size_t RowIndex, size_t ColIndex);
456+
size_t stride, size_t global_rows, size_t global_cols,
457+
size_t row_index, size_t col_index);
427458

428459
} // namespace sycl::ext::intel::experimental::matrix
429460
```
@@ -445,12 +476,12 @@ the following queries to get these requirements:
445476
|Tells the required alignment (in bytes) of the base pointer for
446477
`joint_matrix_load_checked` and `joint_matrix_store_checked`.
447478
|`ext::intel::experimental::info::device::matrix_checked_rowindex_multiple_of<T>`|
448-
`size_t`|Returns a value, of which `RowIndex` must be multiple of;
479+
`size_t`|Returns a value, of which `row_index` must be multiple of;
449480
where `T` is the element type of the matrix. When using the matrices
450481
with the machine learning types, `T` should be the element type
451482
(e.g. `precision::tf32`) not the storage type.
452483
|`ext::intel::experimental::info::device::matrix_checked_globalcols_multiple_of<T>`|
453-
`size_t` | Returns a value, of which `GlobalCols` must be multiple of;
484+
`size_t` | Returns a value, of which `global_cols` must be multiple of;
454485
where `T` is the element type of the matrix. When using the matrices
455486
with the machine learning types, `T` should be the element type
456487
(e.g. `precision::tf32`) not the storage type.
@@ -462,14 +493,19 @@ The checked APIs are currently available in devices with the architecture
462493
`architecture::intel_gpu_pvc`. The following restrictions apply to
463494
these checked APIs:
464495

496+
- The `stride` parameter has the following restrictions:
497+
498+
* The value `stride * sizeof(T1)` must be a multiple of 8, and
499+
* The value of `stride * sizeof(T1)` must not exceed `2^24^`.
500+
465501
- The base pointer must be 4 bytes aligned.
466502

467-
- For 8 bits data type, `RowIndex` must be a multiple of 4. For 16 bits
468-
data type, `RowIndex` must be a multiple of 2. So `RowIndex` must be a
503+
- For 8 bits data type, `row_index` must be a multiple of 4. For 16 bits
504+
data type, `row_index` must be a multiple of 2. So `row_index` must be a
469505
multiple of 4 divided by size of the element type (`4/sizeof(T)`).
470506

471-
- For 8 bits data type, `GlobalCols` must be a multiple of 4. For 16 bits
472-
data type, `GlobalCols` must be a multiple of 2. So `GlobalCols` must be a
507+
- For 8 bits data type, `global_cols` must be a multiple of 4. For 16 bits
508+
data type, `global_cols` must be a multiple of 2. So `global_cols` must be a
473509
multiple of 4 divided by size of the element type (`4/sizeof(T)`).
474510

475511
=== New Device Information Descriptor

sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,86 @@ of `sycl::multi_ptr`. The property list associated with the
326326
`annotated_ptr` argument represents the compile-time constant
327327
properties for cache control included in the SYCL extension link:../../proposed/sycl_ext_intel_cache_controls.asciidoc[sycl_ext_intel_cache_controls]
328328

329+
330+
==== Offset Load
331+
```c++
332+
namespace sycl::ext::oneapi::experimental::matrix {
333+
334+
// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
335+
template <typename Group, typename T1, typename T2,
336+
size_t Rows, size_t Cols,
337+
access::address_space Space, access::decorated IsDecorated>
338+
void joint_matrix_load(Group g,
339+
joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
340+
multi_ptr<T2, Space, IsDecorated> base_src, size_t row_index,
341+
size_t col_index, size_t stride, layout Layout);
342+
343+
// Only available when Layout != layout::dynamic
344+
// and when std::is_same_v<T1, std::remove_const_t<T2>>
345+
template <typename Group, typename T1, typename T2,
346+
size_t Rows, size_t Cols,
347+
use Use, layout Layout,
348+
access::address_space Space, access::decorated IsDecorated>
349+
void joint_matrix_load(Group g,
350+
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
351+
multi_ptr<T2, Space, IsDecorated> base_src, size_t row_index,
352+
size_t col_index, size_t stride);
353+
354+
// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
355+
template <typename Group, typename T1, typename T2,
356+
size_t Rows, size_t Cols,
357+
typename PropertyListT>
358+
void joint_matrix_load(Group g,
359+
joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
360+
annotated_ptr<T2, PropertyListT> base_src, size_t row_index, size_t
361+
col_index, size_t stride, layout Layout);
362+
363+
// Only available when Layout != layout::dynamic
364+
// and when std::is_same_v<T1, std::remove_const_t<T2>>
365+
template <typename Group, typename T1, typename T2,
366+
size_t Rows, size_t Cols, use Use, layout Layout,
367+
typename PropertyListT>
368+
void joint_matrix_load(Group g,
369+
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
370+
annotated_ptr<T2, PropertyListT> base_src, size_t row_index, size_t
371+
col_index, size_t stride);
372+
373+
} // namespace sycl::ext::oneapi::experimental::matrix
374+
```
375+
376+
These overloads of `joint_matrix_load` take the pointer `base_src` to
377+
designate the base pointer of the global memory matrix. The
378+
coordinates `row_index` and `col_index` into the global matrix to
379+
calculate the pointer offset to load/store are given as separate
380+
arguments.
381+
382+
==== Offset Store
383+
```c++
384+
namespace sycl::ext::oneapi::experimental::matrix {
385+
386+
// T1 must be the same as T2
387+
template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
388+
access::address_space Space, access::decorated IsDecorated>
389+
void joint_matrix_store(Group g,
390+
const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
391+
multi_ptr<T2, Space, IsDecorated> base_dest, size_t row_index,
392+
size_t col_index, size_t stride, layout Layout);
393+
394+
template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
395+
typename PropertyListT>
396+
void joint_matrix_store(Group g,
397+
const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
398+
annotated_ptr<T2, PropertyListT> base_dest, size_t row_index, size_t
399+
col_index, size_t stride, layout Layout);
400+
401+
} // namespace sycl::ext::oneapi::experimental::matrix
402+
```
403+
These overloads of `joint_matrix_store` take the pointer `base_dest` to
404+
designate the base pointer of the global memory matrix. The
405+
coordinates `row_index` and `col_index` into the global matrix to
406+
calculate the pointer offset to load/store are given as separate
407+
arguments.
408+
329409
==== Multiply and Add
330410

331411
```c++
@@ -562,7 +642,7 @@ float *buf = malloc_shared<float>(M*K, q);
562642
auto pBuf = address_space_cast<sycl::access::address_space::global_space,
563643
sycl::access::decorated::no>(buf);
564644

565-
joint_matrix_load(sg, tA, pBuf + Offset, Stride);
645+
joint_matrix_load(sg, tA, pBuf + Offset, stride);
566646
```
567647

568648
==== store
@@ -576,7 +656,7 @@ float *buf = malloc_shared<float>(M*K, q);
576656
auto pBuf = address_space_cast<sycl::access::address_space::global_space,
577657
sycl::access::decorated::no>(buf);
578658

579-
joint_matrix_store(sg, tA, pBuf + Offset, Stride, layout::row_major);
659+
joint_matrix_store(sg, tA, pBuf + Offset, stride, layout::row_major);
580660
```
581661

582662
==== fill
@@ -979,7 +1059,7 @@ for (int i = 0; sizeof(combinations); i++) {
9791059
}
9801060
```
9811061

982-
=== Appendix: Supported Combinations Per Hardware
1062+
=== Appendix: Supported Combinations and Restrictions Per Hardware
9831063
The table below provides a list of the combinations that
9841064
`joint_matrix` implementations support on each of Intel AMX and Intel
9851065
XMX hardware. Note that these can be returned using
@@ -1116,6 +1196,27 @@ architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`,
11161196
`architecture::intel_gpu_lnl_m`
11171197
|======================
11181198

1199+
===== Restrictions on `architecture::intel_gpu_pvc`
1200+
1201+
- The `stride` parameter to `joint_matrix_load` and
1202+
`joint_matrix_store` has the following restrictions:
1203+
1204+
* The value `stride * sizeof(T1)` must be a multiple of 8, and
1205+
* The value of `stride * sizeof(T1)` must not exceed `2^24^`.
1206+
1207+
- The base pointer argument to `joint_matrix_load` and
1208+
`joint_matrix_store` must be 4 bytes aligned.
1209+
1210+
- In the case of the offset overloads of `joint_matrix_load` and
1211+
`joint_matrix_store`, for 8 bits data type, `row_index` must be a
1212+
multiple of 4. For 16 bits data type, `row_index` must be a multiple
1213+
of 2. So `row_index` must be a multiple of 4 divided by size of the
1214+
element type (`4/sizeof(T)`).
1215+
1216+
- If these restrictions are not satisfied, users can switch to slower
1217+
implementations of `joint_matrix_load` and `joint_matrix_store` by
1218+
setting the driver flag `IGC_JointMatrixLoadStoreOpt=1`.
1219+
11191220
==== Nvidia Tensor Cores Supported Combinations
11201221
The complete set of matrix data types and shapes that are supported by
11211222
the `ext_oneapi_cuda` backend are represented in the following
