Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,12 @@ using HasUsmKind = HasProperty<usm_kind_key, PropertyListT>;
template <typename PropertyListT>
using HasBufferLocation = HasProperty<buffer_location_key, PropertyListT>;

// Get the value of a property from a property list
template <typename PropKey, typename ConstType, typename DefaultPropVal,
typename PropertyListT>
struct GetPropertyValueFromPropList {};

template <typename PropKey, typename ConstType, typename DefaultPropVal,
typename... Props>
struct GetPropertyValueFromPropList<PropKey, ConstType, DefaultPropVal,
detail::properties_t<Props...>> {
using prop_val_t = std::conditional_t<
detail::ContainsProperty<PropKey, std::tuple<Props...>>::value,
typename detail::FindCompileTimePropertyValueType<
PropKey, std::tuple<Props...>>::type,
DefaultPropVal>;
static constexpr ConstType value =
detail::PropertyMetaInfo<std::remove_const_t<prop_val_t>>::value;
};
detail::properties_t<Props...>>
: GetPropertyValueFromPropList<PropKey, ConstType, DefaultPropVal,
std::tuple<Props...>> {};

// Get the value of alignment from a property list
// If alignment is not present in the property list, set to default value 0
Expand Down
77 changes: 74 additions & 3 deletions sycl/include/sycl/ext/oneapi/kernel_properties/properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
#pragma once

#include <array> // for array
#include <limits>
#include <stddef.h> // for size_t
#include <stdint.h> // for uint32_T
#include <sycl/aspects.hpp> // for aspect
#include <sycl/ext/oneapi/experimental/forward_progress.hpp> // for forward_progress_guarantee enum
#include <sycl/ext/oneapi/properties/property.hpp> // for PropKind
#include <sycl/ext/oneapi/properties/property_utils.hpp> // for SizeListToStr
#include <sycl/ext/oneapi/properties/property_value.hpp> // for property_value
#include <sycl/ext/oneapi/properties/properties.hpp>
#include <type_traits> // for true_type
#include <utility> // for declval
namespace sycl {
Expand Down Expand Up @@ -351,6 +350,78 @@ struct HasKernelPropertiesGetMethod<T,
decltype(std::declval<T>().get(std::declval<properties_tag>()));
};

// Trait for property compile-time meta names and values.
template <typename PropertyT> struct WGSizePropertyMetaInfo {
static constexpr std::array<size_t, 0> WGSize = {};
static constexpr size_t LinearSize = 0;
};

template <size_t Dim0, size_t... Dims>
struct WGSizePropertyMetaInfo<work_group_size_key::value_t<Dim0, Dims...>> {
static constexpr std::array<size_t, sizeof...(Dims) + 1> WGSize = {Dim0,
Dims...};
static constexpr size_t LinearSize = (Dim0 * ... * Dims);
};

template <size_t Dim0, size_t... Dims>
struct WGSizePropertyMetaInfo<max_work_group_size_key::value_t<Dim0, Dims...>> {
static constexpr std::array<size_t, sizeof...(Dims) + 1> WGSize = {Dim0,
Dims...};
static constexpr size_t LinearSize = (Dim0 * ... * Dims);
};

// Get the value of a work-group size related property from a property list
template <typename PropKey, typename PropertiesT>
struct GetWGPropertyFromPropList {};

template <typename PropKey, typename... PropertiesT>
struct GetWGPropertyFromPropList<PropKey, std::tuple<PropertiesT...>> {
using prop_val_t = std::conditional_t<
ContainsProperty<PropKey, std::tuple<PropertiesT...>>::value,
typename FindCompileTimePropertyValueType<
PropKey, std::tuple<PropertiesT...>>::type,
void>;
static constexpr auto WGSize =
WGSizePropertyMetaInfo<std::remove_const_t<prop_val_t>>::WGSize;
static constexpr size_t LinearSize =
WGSizePropertyMetaInfo<std::remove_const_t<prop_val_t>>::LinearSize;
};

// If work_group_size and max_work_group_size coexist, check that the
// dimensionality matches and that the required work-group size doesn't
// trivially exceed the maximum size.
template <typename Properties>
struct ConflictingProperties<max_work_group_size_key, Properties>
: std::false_type {
using WGSizeVal = GetWGPropertyFromPropList<work_group_size_key, Properties>;
using MaxWGSizeVal =
GetWGPropertyFromPropList<max_work_group_size_key, Properties>;
static constexpr size_t Dims = WGSizeVal::WGSize.size();
static_assert(
Dims == 0 || Dims == MaxWGSizeVal::WGSize.size(),
"work_group_size and max_work_group_size dimensionality must match");
static_assert(Dims < 1 || WGSizeVal::WGSize[0] <= MaxWGSizeVal::WGSize[0],
"work_group_size must not exceed max_work_group_size");
static_assert(Dims < 2 || WGSizeVal::WGSize[1] <= MaxWGSizeVal::WGSize[1],
"work_group_size must not exceed max_work_group_size");
static_assert(Dims < 3 || WGSizeVal::WGSize[2] <= MaxWGSizeVal::WGSize[2],
"work_group_size must not exceed max_work_group_size");
};

// If work_group_size and max_linear_work_group_size coexist, check that the
// required linear work-group size doesn't trivially exceed the maximum size.
template <typename Properties>
struct ConflictingProperties<max_linear_work_group_size_key, Properties>
: std::false_type {
using WGSizeVal = GetWGPropertyFromPropList<work_group_size_key, Properties>;
using MaxLinearWGSizeVal =
GetPropertyValueFromPropList<max_linear_work_group_size_key, size_t, void,
Properties>;
static_assert(WGSizeVal::WGSize.empty() ||
WGSizeVal::LinearSize <= MaxLinearWGSizeVal::value,
"work_group_size must not exceed max_linear_work_group_size");
};

} // namespace detail
} // namespace ext::oneapi::experimental
} // namespace _V1
Expand Down
18 changes: 18 additions & 0 deletions sycl/include/sycl/ext/oneapi/properties/properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,24 @@ struct ExtractProperties<PropertyArgsT,
}
};

// Get the value of a property from a property list
template <typename PropKey, typename ConstType, typename DefaultPropVal,
typename PropertiesT>
struct GetPropertyValueFromPropList {};

template <typename PropKey, typename ConstType, typename DefaultPropVal,
typename... PropertiesT>
struct GetPropertyValueFromPropList<PropKey, ConstType, DefaultPropVal,
std::tuple<PropertiesT...>> {
using prop_val_t = std::conditional_t<
ContainsProperty<PropKey, std::tuple<PropertiesT...>>::value,
typename FindCompileTimePropertyValueType<
PropKey, std::tuple<PropertiesT...>>::type,
DefaultPropVal>;
static constexpr ConstType value =
PropertyMetaInfo<std::remove_const_t<prop_val_t>>::value;
};

} // namespace detail

template <typename PropertiesT> class properties {
Expand Down
59 changes: 59 additions & 0 deletions sycl/test/extensions/properties/properties_kernel_negative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,68 @@ void check_sub_group_size() {
KernelFunctorWithSGSize<2>{});
}

void check_max_work_group_size() {
sycl::queue Q;

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size and max_work_group_size dimensionality must match}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2, 2>,
sycl::ext::oneapi::experimental::max_work_group_size<1>},
[]() {});

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2>,
sycl::ext::oneapi::experimental::max_work_group_size<1>},
[]() {});

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2, 2>,
sycl::ext::oneapi::experimental::max_work_group_size<2, 1>},
[]() {});

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2, 2, 2>,
sycl::ext::oneapi::experimental::max_work_group_size<2, 2, 1>},
[]() {});
}

void check_max_linear_work_group_size() {
sycl::queue Q;

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_linear_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2>,
sycl::ext::oneapi::experimental::max_linear_work_group_size<1>},
[]() {});

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_linear_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2, 4>,
sycl::ext::oneapi::experimental::max_linear_work_group_size<7>},
[]() {});

// expected-error-re@sycl/ext/oneapi/kernel_properties/properties.hpp:* {{static assertion failed due to requirement {{.+}}: work_group_size must not exceed max_linear_work_group_size}}
Q.single_task(
sycl::ext::oneapi::experimental::properties{
sycl::ext::oneapi::experimental::work_group_size<2, 4, 2>,
sycl::ext::oneapi::experimental::max_linear_work_group_size<15>},
[]() {});
}

int main() {
check_work_group_size();
check_work_group_size_hint();
check_sub_group_size();
check_max_work_group_size();
check_max_linear_work_group_size();
return 0;
}
Loading