@@ -171,6 +171,37 @@ void joint_matrix_store(Group g,
171171 ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dest,
172172 size_t stride);
173173
174+ // Overloads for offset store
175+ template <typename Group, typename T, size_t Rows, size_t Cols,
176+ layout Layout, access::address_space Space,
177+ access::decorated IsDecorated>
178+ void joint_matrix_store(Group g,
179+ const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
180+ multi_ptr<T, Space, IsDecorated> base_dest, size_t row_index,
181+ size_t col_index, size_t stride);
182+
183+ template <typename Group, typename T, size_t Rows, size_t Cols,
184+ layout Layout, access::address_space Space,
185+ access::decorated IsDecorated>
186+ void joint_matrix_store(Group g,
187+ const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
188+ multi_ptr<T, Space, IsDecorated> base_dest, size_t row_index,
189+ size_t col_index, size_t stride);
190+
191+ template <typename Group, typename T, size_t Rows, size_t Cols,
192+ layout Layout, typename PropertyListT>
193+ void joint_matrix_store(Group g,
194+ const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
195+ ext::oneapi::experimental::annotated_ptr<T, PropertyListT>
196+ base_dest, size_t row_index, size_t col_index, size_t stride);
197+
198+ template <typename Group, typename T, size_t Rows, size_t Cols,
199+ layout Layout, typename PropertyListT>
200+ void joint_matrix_store(Group g,
201+ const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
202+ ext::oneapi::experimental::annotated_ptr<T, PropertyListT>
203+ base_dest, size_t row_index, size_t col_index, size_t stride);
204+
174205} // namespace sycl::ext::intel::experimental::matrix
175206```
176207
@@ -244,19 +275,19 @@ supporting the out of bounds checked APIs that are defined in this section.
244275In this section, we refer to the memory buffer where a `joint_matrix`
245276is loaded from or stored to as the global matrix. This global matrix
246277is also interpreted as a two-dimensional memory region as follows, where
247- `GlobalRows ` is number of rows in the global matrix, `GlobalCols ` is number of
248- columns in the global matrix, `Stride ` is number of columns that include
278+ `global_rows ` is number of rows in the global matrix, `global_cols ` is number of
279+ columns in the global matrix, `stride ` is number of columns that include
249280the out of bounds data (depicted as x here).
250281
251282```
252- GlobalCols
283+ global_cols
253284 <----------->
254285 dddddddddddddxxx ^
255- dddddddddddddxxx | GlobalRows
286+ dddddddddddddxxx | global_rows
256287 dddddddddddddxxx v
257288 xxxxxxxxxxxxxxxx
258289 <-------------->
259- Stride
290+ stride
260291```
261292
262293In the diagram above, the global matrix has 13 columns and 3
@@ -293,15 +324,15 @@ checking, namely `joint_matrix_fill`, `joint_matrix_load`, and
293324the global memory matrix, which is different from the APIs that do not
294325do bounds checking. Those non-bounds-checking APIs take a pointer to
295326the base of the joint matrix.
296- * The coordinates `RowIndex ` and `ColIndex ` into the global matrix to
327+ * The coordinates `row_index ` and `col_index ` into the global matrix to
297328calculate the pointer offset to load/store are given as separate
298329arguments.
299330* These variants take extra arguments to determine the global bounds
300- `GlobalRows ` and `GlobalCols ` of the global matrix.
331+ `global_rows ` and `global_cols ` of the global matrix.
301332
302333To illustrate the out-of-bounds checking, consider the global matrix
303- shown above which has 13 columns and 3 rows (`GlobalRows =3` and
304- `GlobalCols =13`), where the joint matrix size is 8 columns by 2 rows defined as
334+ shown above which has 13 columns and 3 rows (`global_rows =3` and
335+ `global_cols =13`), where the joint matrix size is 8 columns by 2 rows defined as
305336```
306337joint_matrix<sub_group, bfloat16, use::b, 2, 8, layout::row_major> sub_b;
307338```
@@ -311,14 +342,14 @@ both dimensions. This is shown below, where capital letters correspond
311342to the elements that are accessed by this joint matrix load:
312343
313344```
314- GlobalCols
345+ global_cols
315346 <----------->
316347 dddddddddddddxxx ^
317- dddddddddddddxxx | GlobalRows
348+ dddddddddddddxxx | global_rows
318349 ddddddddDDDDDXXX v
319350 xxxxxxxxXXXXXXXX
320351 <-------------->
321- Stride
352+ stride
322353```
323354
324355If the joint matrix is loaded via `joint_matrix_load_checked` using
@@ -335,18 +366,18 @@ namespace sycl::ext::intel::experimental::matrix {
335366template <typename Group, typename T, size_t Rows, size_t Cols,
336367 use Use, layout Layout, typename Tv>
337368void joint_matrix_fill_checked(Group g, joint_matrix<Group, T, Use, Rows,
338- Cols, Layout> &m, Tv v, size_t GlobalRows , size_t GlobalCols ,
339- size_t RowIndex , size_t ColIndex );
369+ Cols, Layout> &m, Tv v, size_t global_rows , size_t global_cols ,
370+ size_t row_index , size_t col_index );
340371
341372// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
342373template <typename Group, typename T1, typename T2,
343374 size_t Rows, size_t Cols,
344375 access::address_space Space, access::decorated IsDecorated>
345376void joint_matrix_load_checked(Group g,
346377 joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
347- multi_ptr<T2, Space, IsDecorated> base_src, size_t Stride ,
348- layout Layout, size_t GlobalRows , size_t GlobalCols ,
349- size_t RowIndex , size_t ColIndex );
378+ multi_ptr<T2, Space, IsDecorated> base_src, size_t stride ,
379+ layout Layout, size_t global_rows , size_t global_cols ,
380+ size_t row_index , size_t col_index );
350381
351382// Only available when Layout != layout::dynamic
352383// and when std::is_same_v<T1, std::remove_const_t<T2>>
@@ -356,17 +387,17 @@ template <typename Group, typename T1, typename T2,
356387 access::address_space Space, access::decorated IsDecorated>
357388void joint_matrix_load_checked(Group g,
358389 joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
359- multi_ptr<T2, Space, IsDecorated> base_src, size_t Stride ,
360- size_t GlobalRows , size_t GlobalCols , size_t RowIndex , size_t ColIndex );
390+ multi_ptr<T2, Space, IsDecorated> base_src, size_t stride ,
391+ size_t global_rows , size_t global_cols , size_t row_index , size_t col_index );
361392
362393// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
363394template <typename Group, typename T1, typename T2,
364395 size_t Rows, size_t Cols, typename PropertyListT>
365396void joint_matrix_load_checked(Group g,
366397 joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
367398 ext::oneapi::experimental::annotated_ptr<T2, PropertyListT> base_src,
368- size_t Stride , layout Layout, size_t GlobalRows , size_t GlobalCols ,
369- size_t RowIndex , size_t ColIndex );
399+ size_t stride , layout Layout, size_t global_rows , size_t global_cols ,
400+ size_t row_index , size_t col_index );
370401
371402// Only available when Layout != layout::dynamic
372403// and when std::is_same_v<T1, std::remove_const_t<T2>>
@@ -375,55 +406,55 @@ template <typename Group, typename T1, typename T2, size_t Rows,
375406void joint_matrix_load_checked(Group g,
376407 joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
377408 ext::oneapi::experimental::annotated_ptr<T2, PropertyListT> base_src,
378- size_t Stride , size_t GlobalRows , size_t GlobalCols ,
379- size_t RowIndex , size_t ColIndex );
409+ size_t stride , size_t global_rows , size_t global_cols ,
410+ size_t row_index , size_t col_index );
380411
381412template <typename Group, typename T, size_t Rows, size_t Cols,
382413 access::address_space Space, access::decorated IsDecorated>
383414void joint_matrix_store_checked(Group g,
384415 const joint_matrix<Group, T, use::accumulator, Rows, Cols, layout::dynamic> &res,
385- multi_ptr<T, Space, IsDecorated> base_dest, size_t Stride , layout Layout,
386- size_t GlobalRows , size_t GlobalCols , size_t RowIndex , size_t ColIndex );
416+ multi_ptr<T, Space, IsDecorated> base_dest, size_t stride , layout Layout,
417+ size_t global_rows , size_t global_cols , size_t row_index , size_t col_index );
387418
388419template <typename Group, typename T, size_t Rows, size_t Cols,
389420 layout Layout, access::address_space Space,
390421 access::decorated IsDecorated>
391422void joint_matrix_store_checked(Group g,
392423 const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
393- multi_ptr<T, Space, IsDecorated> base_dest, size_t Stride ,
394- size_t GlobalRows , size_t GlobalCols , size_t RowIndex , size_t ColIndex );
424+ multi_ptr<T, Space, IsDecorated> base_dest, size_t stride ,
425+ size_t global_rows , size_t global_cols , size_t row_index , size_t col_index );
395426
396427template <typename Group, typename T, size_t Rows, size_t Cols,
397428 layout Layout, access::address_space Space,
398429 access::decorated IsDecorated>
399430void joint_matrix_store_checked(Group g,
400431 const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
401- multi_ptr<T, Space, IsDecorated> base_dest, size_t Stride ,
402- size_t GlobalRows , size_t GlobalCols , size_t RowIndex , size_t ColIndex );
432+ multi_ptr<T, Space, IsDecorated> base_dest, size_t stride ,
433+ size_t global_rows , size_t global_cols , size_t row_index , size_t col_index );
403434
404435template <typename Group, typename T, size_t Rows, size_t Cols,
405436 typename PropertyListT>
406437void joint_matrix_store_checked(Group g,
407438 const joint_matrix<Group, T, use::accumulator, Rows, Cols, layout::dynamic> &res,
408439 ext::oneapi::experimental::annotated_ptr<T, PropertyListT> base_dest,
409- size_t Stride , layout Layout, size_t GlobalRows , size_t GlobalCols ,
410- size_t RowIndex , size_t ColIndex );
440+ size_t stride , layout Layout, size_t global_rows , size_t global_cols ,
441+ size_t row_index , size_t col_index );
411442
412443template <typename Group, typename T, size_t Rows, size_t Cols,
413444 layout Layout, typename PropertyListT>
414445void joint_matrix_store_checked(Group g,
415446 const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
416447 ext::oneapi::experimental::annotated_ptr<T, PropertyListT> base_dest,
417- size_t Stride , size_t GlobalRows , size_t GlobalCols ,
418- size_t RowIndex , size_t ColIndex );
448+ size_t stride , size_t global_rows , size_t global_cols ,
449+ size_t row_index , size_t col_index );
419450
420451template <typename Group, typename T, size_t Rows, size_t Cols,
421452 layout Layout, typename PropertyListT>
422453void joint_matrix_store_checked(Group g,
423454 const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
424455 ext::oneapi::experimental::annotated_ptr<T, PropertyListT> base_dest,
425- size_t Stride , size_t GlobalRows , size_t GlobalCols ,
426- size_t RowIndex , size_t ColIndex );
456+ size_t stride , size_t global_rows , size_t global_cols ,
457+ size_t row_index , size_t col_index );
427458
428459} // namespace sycl::ext::intel::experimental::matrix
429460```
@@ -445,12 +476,12 @@ the following queries to get these requirements:
445476|Tells the required alignment (in bytes) of the base pointer for
446477`joint_matrix_load_checked` and `joint_matrix_store_checked`.
447478|`ext::intel::experimental::info::device::matrix_checked_rowindex_multiple_of<T>`|
448- `size_t`|Returns a value, of which `RowIndex ` must be multiple of;
479+ `size_t`|Returns a value, of which `row_index ` must be multiple of;
449480where `T` is the element type of the matrix. When using the matrices
450481with the machine learning types, `T` should be the element type
451482(e.g. `precision::tf32`) not the storage type.
452483|`ext::intel::experimental::info::device::matrix_checked_globalcols_multiple_of<T>`|
453- `size_t` | Returns a value, of which `GlobalCols ` must be multiple of;
484+ `size_t` | Returns a value, of which `global_cols ` must be multiple of;
454485where `T` is the element type of the matrix. When using the matrices
455486with the machine learning types, `T` should be the element type
456487(e.g. `precision::tf32`) not the storage type.
@@ -462,14 +493,19 @@ The checked APIs are currently available in devices with the architecture
462493`architecture::intel_gpu_pvc`. The following restrictions apply to
463494these checked APIs:
464495
496+ - The `stride` parameter has the following restrictions:
497+
498+ * The value `stride * sizeof(T1)` must be a multiple of 8, and
499+ * The value of `stride * sizeof(T1)` must not exceed `2^24^`.
500+
465501- The base pointer must be 4 bytes aligned.
466502
467- - For 8 bits data type, `RowIndex ` must be a multiple of 4. For 16 bits
468- data type, `RowIndex ` must be a multiple of 2. So `RowIndex ` must be a
503+ - For 8 bits data type, `row_index ` must be a multiple of 4. For 16 bits
504+ data type, `row_index ` must be a multiple of 2. So `row_index ` must be a
469505multiple of 4 divided by size of the element type (`4/sizeof(T)`).
470506
471- - For 8 bits data type, `GlobalCols ` must be a multiple of 4. For 16 bits
472- data type, `GlobalCols ` must be a multiple of 2. So `GlobalCols ` must be a
507+ - For 8 bits data type, `global_cols ` must be a multiple of 4. For 16 bits
508+ data type, `global_cols ` must be a multiple of 2. So `global_cols ` must be a
473509multiple of 4 divided by size of the element type (`4/sizeof(T)`).
474510
475511=== New Device Information Descriptor
0 commit comments