Skip to content

Commit 05d29f3

Browse files
authored
[ESIMD] Allow full autodeduction for scatter USM APIs accepting simd_view (#13941)
1 parent dec1146 commit 05d29f3

File tree

2 files changed

+256
-10
lines changed

2 files changed

+256
-10
lines changed

sycl/include/sycl/ext/intel/esimd/memory.hpp

Lines changed: 221 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -911,10 +911,46 @@ scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
911911
}
912912
}
913913

914-
// template <typename T, int N, int VS = 1, typename OffsetT,
915-
// typename PropertyListT = empty_properties_t>
916-
// void scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
917-
// PropertyListT props = {}); // (usm-sc-2)
914+
/// template <int VS = 1, typename OffsetT, typename ValuesSimdViewT, typename
915+
/// T, int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
916+
/// typename PropertyListT = empty_properties_t>
917+
/// void scatter(T *p, simd<OffsetT, N / VS> byte_offsets, ValuesSimdViewT vals,
918+
/// simd_mask<N / VS> mask, PropertyListT props = {});
919+
///
920+
/// Variation of the API that allows using \c simd_view without specifying \c T
921+
/// and \c N template parameters.
922+
/// Writes ("scatters") elements of the input vector to different memory
923+
/// locations. Each memory location is base address plus an offset - a
924+
/// value of the corresponding element in the input offset vector. Access to
925+
/// any element's memory location can be disabled via the input mask.
926+
/// @tparam VS Vector size. It can also be read as the number of writes per each
927+
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
928+
/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
929+
/// @param p The base address.
930+
/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes.
931+
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
932+
/// If the alignment property is not passed, then it is assumed that each
933+
/// accessed address is aligned by element-size.
934+
/// @param vals The vector to scatter.
935+
/// @param mask The access mask.
936+
/// @param props The optional compile-time properties. Only 'alignment'
937+
/// and cache hint properties are used.
938+
// USM scatter overload taking the values as a simd_view. T can only be deduced
// from the pointer 'p'; N is not deducible from 'N / VS' and falls back to its
// default, the view's getSizeX() * getSizeY().
template <
939+
int VS = 1, typename OffsetT, typename ValuesSimdViewT, typename T,
940+
int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
941+
typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
942+
// Enabled only when is_simd_view_type_v<ValuesSimdViewT> and
// is_property_list_v<PropertyListT> both hold.
__ESIMD_API std::enable_if_t<
943+
detail::is_simd_view_type_v<ValuesSimdViewT> &&
944+
ext::oneapi::experimental::is_property_list_v<PropertyListT>>
945+
scatter(T *p, simd<OffsetT, N / VS> byte_offsets, ValuesSimdViewT vals,
946+
simd_mask<N / VS> mask, PropertyListT props = {}) {
947+
// Forward to the scatter<T, N, VS> overload, passing the result of
// vals.read() in place of the view.
scatter<T, N, VS>(p, byte_offsets, vals.read(), mask, props);
948+
}
949+
950+
/// template <typename T, int N, int VS = 1, typename OffsetT,
951+
/// typename PropertyListT = empty_properties_t>
952+
/// void scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
953+
/// PropertyListT props = {}); // (usm-sc-2)
918954
///
919955
/// Writes ("scatters") elements of the input vector to different memory
920956
/// locations. Each memory location is base address plus an offset - a
@@ -943,10 +979,80 @@ scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
943979
scatter<T, N, VS>(p, byte_offsets, vals, Mask, props);
944980
}
945981

946-
// template <typename T, int N, int VS = 1, typename OffsetSimdViewT,
947-
// typename PropertyListT = empty_properties_t>
948-
// void scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
949-
// simd_mask<N / VS> mask, PropertyListT props = {}); // (usm-sc-3)
982+
/// template <int VS = 1, typename OffsetSimdViewT, typename ValuesSimdViewT,
983+
/// typename T, int N = ValuesSimdViewT::getSizeX() *
984+
/// ValuesSimdViewT::getSizeY(), typename PropertyListT = empty_properties_t>
985+
/// void scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals,
986+
/// simd_mask<N / VS> mask, PropertyListT props = {});
987+
///
988+
/// Variation of the API that allows using \c simd_view without specifying \c T
989+
/// and \c N template parameters.
990+
/// Writes ("scatters") elements of the input vector to different memory
991+
/// locations. Each memory location is base address plus an offset - a
992+
/// value of the corresponding element in the input offset vector.
993+
/// @tparam VS Vector size. It can also be read as the number of writes per each
994+
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
995+
/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
996+
/// @param p The base address.
997+
/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes
/// represented as a 'simd_view' object.
998+
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
999+
/// If the alignment property is not passed, then it is assumed that each
1000+
/// accessed address is aligned by element-size.
1001+
/// @param vals The vector to scatter.
1002+
/// @param mask The access mask.
1003+
/// @param props The optional compile-time properties. Only 'alignment'
1004+
/// and cache hint properties are used.
1005+
// USM scatter overload taking both the offsets and the values as simd_view
// objects, plus an explicit mask. T can only be deduced from the pointer 'p';
// N falls back to its default, the values view's getSizeX() * getSizeY().
template <
1006+
int VS = 1, typename OffsetSimdViewT, typename ValuesSimdViewT, typename T,
1007+
int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
1008+
typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
1009+
// Enabled only when both view types satisfy is_simd_view_type_v and
// PropertyListT satisfies is_property_list_v.
__ESIMD_API std::enable_if_t<
1010+
detail::is_simd_view_type_v<ValuesSimdViewT> &&
1011+
detail::is_simd_view_type_v<OffsetSimdViewT> &&
1012+
ext::oneapi::experimental::is_property_list_v<PropertyListT>>
1013+
scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals,
1014+
simd_mask<N / VS> mask, PropertyListT props = {}) {
1015+
// Forward to the scatter<T, N, VS> overload, passing the results of
// byte_offsets.read() and vals.read() in place of the views.
scatter<T, N, VS>(p, byte_offsets.read(), vals.read(), mask, props);
1016+
}
1017+
1018+
/// template <int VS = 1, typename OffsetT, typename ValuesSimdViewT, typename
1019+
/// T, int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
1020+
/// typename PropertyListT = empty_properties_t>
1021+
/// void scatter(T *p, simd<OffsetT, N / VS> byte_offsets, ValuesSimdViewT vals,
1022+
/// PropertyListT props = {});
1023+
///
1024+
/// Variation of the API that allows using \c simd_view without specifying \c T
1025+
/// and \c N template parameters.
1026+
/// Writes ("scatters") elements of the input vector to different memory
1027+
/// locations. Each memory location is base address plus an offset - a
1028+
/// value of the corresponding element in the input offset vector.
1029+
/// @tparam VS Vector size. It can also be read as the number of writes per each
1030+
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
1031+
/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
1032+
/// @param p The base address.
1033+
/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes.
1034+
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
1035+
/// If the alignment property is not passed, then it is assumed that each
1036+
/// accessed address is aligned by element-size.
1037+
/// @param vals The vector to scatter.
1038+
/// @param props The optional compile-time properties. Only 'alignment'
1039+
/// and cache hint properties are used.
1040+
// Mask-less USM scatter overload taking the values as a simd_view. T can only
// be deduced from the pointer 'p'; N falls back to its default, the view's
// getSizeX() * getSizeY().
template <
1041+
int VS = 1, typename OffsetT, typename ValuesSimdViewT, typename T,
1042+
int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
1043+
typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
1044+
// Enabled only when is_simd_view_type_v<ValuesSimdViewT> and
// is_property_list_v<PropertyListT> both hold.
__ESIMD_API std::enable_if_t<
1045+
detail::is_simd_view_type_v<ValuesSimdViewT> &&
1046+
ext::oneapi::experimental::is_property_list_v<PropertyListT>>
1047+
scatter(T *p, simd<OffsetT, N / VS> byte_offsets, ValuesSimdViewT vals,
1048+
PropertyListT props = {}) {
1049+
// Forward to the mask-less scatter<T, N, VS> overload, passing the result of
// vals.read() in place of the view.
scatter<T, N, VS>(p, byte_offsets, vals.read(), props);
1050+
}
1051+
1052+
/// template <typename T, int N, int VS = 1, typename OffsetSimdViewT,
1053+
/// typename PropertyListT = empty_properties_t>
1054+
/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
1055+
/// simd_mask<N / VS> mask, PropertyListT props = {}); // (usm-sc-3)
9501056
///
9511057
/// Writes ("scatters") elements of the input vector to different memory
9521058
/// locations. Each memory location is base address plus an offset - a
@@ -978,6 +1084,75 @@ scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
9781084
scatter<T, N, VS>(p, byte_offsets.read(), vals, mask, props);
9791085
}
9801086

1087+
/// template <int VS, typename OffsetSimdViewT, typename T, int N, typename
1088+
/// PropertyListT = empty_properties_t>
1089+
/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd<T,N> vals,
1090+
/// simd_mask<N / VS> mask, PropertyListT props = {});
1091+
///
1092+
/// Variation of the API that allows using \c simd_view without specifying \c T
1093+
/// and \c N template parameters.
1094+
/// Writes ("scatters") elements of the input vector to different memory
1095+
/// locations. Each memory location is base address plus an offset - a
1096+
/// value of the corresponding element in the input offset vector. Access to
1097+
/// any element's memory location can be disabled via the input mask.
1098+
/// @tparam VS Vector size. It can also be read as the number of writes per each
1099+
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
1100+
/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
1101+
/// @param p The base address.
1102+
/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes
1103+
/// represented as a 'simd_view' object.
1104+
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
1105+
/// If the alignment property is not passed, then it is assumed that each
1106+
/// accessed address is aligned by element-size.
1107+
/// @param vals The vector to scatter.
1108+
/// @param mask The access mask.
1109+
/// @param props The optional compile-time properties. Only 'alignment'
1110+
/// and cache hint properties are used.
1111+
// USM scatter overload taking the offsets as a simd_view and the values as a
// simd<T, N>. VS is the first template parameter and has no default here, so
// the caller has to spell it out (e.g. scatter<2>(...)); T and N can then be
// deduced from 'p' and 'vals'.
template <
1112+
int VS, typename OffsetSimdViewT, typename T, int N,
1113+
typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
1114+
// Enabled only when is_simd_view_type_v<OffsetSimdViewT> and
// is_property_list_v<PropertyListT> both hold.
__ESIMD_API std::enable_if_t<
1115+
detail::is_simd_view_type_v<OffsetSimdViewT> &&
1116+
ext::oneapi::experimental::is_property_list_v<PropertyListT>>
1117+
scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
1118+
simd_mask<N / VS> mask, PropertyListT props = {}) {
1119+
// Forward to the scatter<T, N, VS> overload, passing the result of
// byte_offsets.read() in place of the view.
scatter<T, N, VS>(p, byte_offsets.read(), vals, mask, props);
1120+
}
1121+
1122+
/// template <int VS, typename OffsetSimdViewT, typename T, int N, typename
1123+
/// PropertyListT = empty_properties_t>
1124+
/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd<T,N> vals,
1125+
/// PropertyListT props = {});
1126+
///
1127+
/// Variation of the API that allows using \c simd_view without specifying \c T
1128+
/// and \c N template parameters.
1129+
/// Writes ("scatters") elements of the input vector to different memory
1130+
/// locations. Each memory location is base address plus an offset - a
1131+
/// value of the corresponding element in the input offset vector. This
1132+
/// overload takes no mask parameter.
1133+
/// @tparam VS Vector size. It can also be read as the number of writes per each
1134+
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
1135+
/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
1136+
/// @param p The base address.
1137+
/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes
1138+
/// represented as a 'simd_view' object.
1139+
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
1140+
/// If the alignment property is not passed, then it is assumed that each
1141+
/// accessed address is aligned by element-size.
1142+
/// @param vals The vector to scatter.
1143+
/// @param props The optional compile-time properties. Only 'alignment'
1144+
/// and cache hint properties are used.
1145+
// Mask-less USM scatter overload taking the offsets as a simd_view and the
// values as a simd<T, N>. VS is the first template parameter and has no default
// here, so the caller has to spell it out (e.g. scatter<2>(...)); T and N can
// then be deduced from 'p' and 'vals'.
template <
1146+
int VS, typename OffsetSimdViewT, typename T, int N,
1147+
typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
1148+
// Enabled only when is_simd_view_type_v<OffsetSimdViewT> and
// is_property_list_v<PropertyListT> both hold.
__ESIMD_API std::enable_if_t<
1149+
detail::is_simd_view_type_v<OffsetSimdViewT> &&
1150+
ext::oneapi::experimental::is_property_list_v<PropertyListT>>
1151+
scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
1152+
PropertyListT props = {}) {
1153+
// Forward to the mask-less scatter<T, N, VS> overload, passing the result of
// byte_offsets.read() in place of the view.
scatter<T, N, VS>(p, byte_offsets.read(), vals, props);
1154+
}
1155+
9811156
/// template <typename T, int N, int VS = 1, typename OffsetSimdViewT,
9821157
/// typename PropertyListT = empty_properties_t>
9831158
/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
@@ -1012,6 +1187,44 @@ scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
10121187
scatter<T, N, VS>(p, byte_offsets.read(), vals, Mask, props);
10131188
}
10141189

1190+
/// template <int VS = 1, typename OffsetSimdViewT, typename
1191+
/// ValuesSimdViewT, typename T, int N = ValuesSimdViewT::getSizeX() *
1192+
/// ValuesSimdViewT::getSizeY(), typename PropertyListT =
1193+
/// empty_properties_t>
1194+
/// void scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals,
1195+
/// PropertyListT props = {});
1196+
///
1197+
/// Variation of the API that allows using \c simd_view without specifying \c T
1198+
/// and \c N template parameters.
1199+
/// Writes ("scatters") elements of the input vector to different memory
1200+
/// locations. Each memory location is base address plus an offset - a
1201+
/// value of the corresponding element in the input offset vector.
1202+
/// @tparam VS Vector size. It can also be read as the number of writes per each
1203+
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
1204+
/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
1205+
/// @param p The base address.
1206+
/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes
1207+
/// represented as a 'simd_view' object.
1208+
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
1209+
/// If the alignment property is not passed, then it is assumed that each
1210+
/// accessed address is aligned by element-size.
1211+
/// @param vals The vector to scatter.
1212+
/// @param props The optional compile-time properties. Only 'alignment'
1213+
/// and cache hint properties are used.
1214+
// Mask-less USM scatter overload taking both the offsets and the values as
// simd_view objects. T can only be deduced from the pointer 'p'; N falls back
// to its default, the values view's getSizeX() * getSizeY().
template <
1215+
int VS = 1, typename OffsetSimdViewT, typename ValuesSimdViewT, typename T,
1216+
int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
1217+
typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
1218+
// Enabled only when both view types satisfy is_simd_view_type_v and
// PropertyListT satisfies is_property_list_v.
__ESIMD_API std::enable_if_t<
1219+
detail::is_simd_view_type_v<OffsetSimdViewT> &&
1220+
detail::is_simd_view_type_v<ValuesSimdViewT> &&
1221+
ext::oneapi::experimental::is_property_list_v<PropertyListT>>
1222+
scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals,
1223+
PropertyListT props = {}) {
1224+
// Build a mask from the scalar 1 and go through the masked overload.
// NOTE(review): presumably this enables every one of the N / VS lanes -
// confirm against simd_mask's scalar constructor.
simd_mask<N / VS> Mask = 1;
1225+
// Forward to the scatter<T, N, VS> overload, passing the results of
// byte_offsets.read() and vals.read() in place of the views.
scatter<T, N, VS>(p, byte_offsets.read(), vals.read(), Mask, props);
1226+
}
1227+
10151228
/// A variation of \c scatter API with \c offsets represented as scalar.
10161229
///
10171230
/// @tparam Tx Element type, must be of size 4 or less.

sycl/test/esimd/memory_properties_scatter.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
9696

9797
scatter(ptrf, ioffset_n32, usm, props_align4);
9898

99-
// CHECK-COUNT-14: call void @llvm.genx.lsc.store.stateless.v32i1.v32i64.v32i32(<32 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 1, i8 1, i8 0, <32 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0)
99+
// CHECK-COUNT-22: call void @llvm.genx.lsc.store.stateless.v32i1.v32i64.v32i32(<32 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 1, i8 1, i8 0, <32 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0)
100100
scatter(ptrf, ioffset_n32, usm, mask_n32, props_cache_load);
101101
scatter(ptrf, ioffset_n32, usm, props_cache_load);
102102

@@ -110,6 +110,12 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
110110
props_cache_load);
111111
scatter<float, 32>(ptrf, ioffset_n32_view, usm_view, props_cache_load);
112112

113+
scatter(ptrf, ioffset_n32, usm_view, mask_n32, props_cache_load);
114+
scatter(ptrf, ioffset_n32, usm_view, props_cache_load);
115+
116+
scatter(ptrf, ioffset_n32_view, usm_view, mask_n32, props_cache_load);
117+
scatter(ptrf, ioffset_n32_view, usm_view, props_cache_load);
118+
113119
scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm, mask_n32,
114120
props_cache_load);
115121
scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm, props_cache_load);
@@ -123,9 +129,17 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
123129
usm_view.select<32, 1>(), mask_n32, props_cache_load);
124130
scatter<float, 32>(ptrf, ioffset_n32_view.select<32, 1>(),
125131
usm_view.select<32, 1>(), props_cache_load);
132+
scatter(ptrf, ioffset_n32, usm_view.select<32, 1>(), mask_n32,
133+
props_cache_load);
134+
scatter(ptrf, ioffset_n32, usm_view.select<32, 1>(), props_cache_load);
135+
136+
scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm_view.select<32, 1>(),
137+
mask_n32, props_cache_load);
138+
scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm_view.select<32, 1>(),
139+
props_cache_load);
126140

127141
// VS > 1
128-
// CHECK-COUNT-14: call void @llvm.genx.lsc.store.stateless.v16i1.v16i64.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0)
142+
// CHECK-COUNT-24: call void @llvm.genx.lsc.store.stateless.v16i1.v16i64.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0)
129143
scatter<float, 32, 2>(ptrf, ioffset_n16, usm, mask_n16, props_cache_load);
130144

131145
scatter<float, 32, 2>(ptrf, ioffset_n16, usm, props_cache_load);
@@ -147,6 +161,16 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
147161
scatter<float, 32, 2>(ptrf, ioffset_n16_view.select<16, 1>(), usm,
148162
props_cache_load);
149163

164+
scatter<2>(ptrf, ioffset_n16, usm_view, mask_n16, props_cache_load);
165+
scatter<2>(ptrf, ioffset_n16, usm_view, props_cache_load);
166+
167+
scatter<2>(ptrf, ioffset_n16_view, usm_view, mask_n16, props_cache_load);
168+
scatter<2>(ptrf, ioffset_n16_view, usm_view, props_cache_load);
169+
170+
scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm, mask_n16,
171+
props_cache_load);
172+
scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm, props_cache_load);
173+
150174
scatter<float, 32, 2>(ptrf, ioffset_n16, usm_view.select<32, 1>(), mask_n16,
151175
props_cache_load);
152176
scatter<float, 32, 2>(ptrf, ioffset_n16, usm_view.select<32, 1>(),
@@ -157,6 +181,15 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
157181
scatter<float, 32, 2>(ptrf, ioffset_n16_view.select<16, 1>(),
158182
usm_view.select<32, 1>(), props_cache_load);
159183

184+
scatter<2>(ptrf, ioffset_n16, usm_view.select<32, 1>(), mask_n16,
185+
props_cache_load);
186+
scatter<2>(ptrf, ioffset_n16, usm_view.select<32, 1>(), props_cache_load);
187+
188+
scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm_view.select<32, 1>(),
189+
mask_n16, props_cache_load);
190+
scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm_view.select<32, 1>(),
191+
props_cache_load);
192+
160193
// CHECK-COUNT-14: call void @llvm.genx.lsc.store.stateless.v16i1.v16i64.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 0, i8 0, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0)
161194
scatter<float, 32, 2>(ptrf, ioffset_n16, usm, mask_n16);
162195

0 commit comments

Comments
 (0)