Skip to content

Commit fc70e6f

Browse files
committed
a splitmix64 seeder for hydra, fix hdra::unweight
1 parent 3c1033c commit fc70e6f

File tree

3 files changed

+72
-7
lines changed

3 files changed

+72
-7
lines changed

hydra/Random.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,26 @@ struct is_callable: std::conditional<
150150
} // namespace detail
151151

152152

153+
/**
154+
* \ingroup random
155+
*
156+
* This functions reorder a dataset to produce a unweighted sample according to the weights
157+
* [wbegin, wend]. The length of the range [wbegin, wend] should be equal or greater than
158+
* the dataset size.
159+
*
160+
* @param policy parallel backend to perform the unweighting
161+
* @param data_begin iterator pointing to the begin of the range of weights
162+
* @param data_end iterator pointing to the begin of the range of weights
163+
* @param weights_begin iterator pointing to the begin of the range of data
164+
* @return
165+
*/
166+
template<typename RNG=default_random_engine, typename DerivedPolicy, typename IteratorData, typename IteratorWeight>
167+
typename std::enable_if<
168+
detail::random::is_iterator<IteratorData>::value && detail::random::is_iterator<IteratorWeight>::value,
169+
Range<IteratorData> >::type
170+
unweight( hydra_thrust::detail::execution_policy_base<DerivedPolicy> const& policy, IteratorData data_begin, IteratorData data_end, IteratorWeight weights_begin);
171+
172+
153173

154174

155175
/**
@@ -227,6 +247,25 @@ typename std::enable_if<
227247
unweight( IterableData data, IterableWeight weights);
228248

229249

250+
/**
251+
* \ingroup random
252+
*
253+
* This functions reorder a dataset to produce an unweighted sample according to @param functor .
254+
*
255+
* @param policy
256+
* @param begin
257+
* @param end
258+
* @param functor
259+
* @return the index of the last entry of the unweighted event.
260+
*/
261+
template<typename RNG=default_random_engine, typename Functor, typename Iterator, typename DerivedPolicy>
262+
typename std::enable_if<
263+
detail::random::is_callable<Functor>::value && detail::random::is_iterator<Iterator>::value,
264+
Range<Iterator>
265+
>::type
266+
unweight( hydra_thrust::detail::execution_policy_base<DerivedPolicy> const& policy, Iterator begin, Iterator end, Functor const& functor);
267+
268+
230269
/**
231270
* \ingroup random
232271
*

hydra/SeedRNG.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class SeedRNG
5151
{}
5252

5353
__hydra_host__ __hydra_device__
54-
inline SeetRNG& operator=(SeedRNG const& other)
54+
inline SeedRNG& operator=(SeedRNG const& other)
5555
{
5656
if(this==&other) return *this;
5757

hydra/detail/Random.inl

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,12 @@
3939

4040
namespace hydra{
4141

42-
43-
template<typename RNG, typename IteratorData, typename IteratorWeight, hydra::detail::Backend BACKEND>
42+
template<typename RNG, typename DerivedPolicy, typename IteratorData, typename IteratorWeight>
4443
typename std::enable_if<
4544
detail::random::is_iterator<IteratorData>::value && detail::random::is_iterator<IteratorWeight>::value,
4645
Range<IteratorData>
4746
>::type
48-
unweight( detail::BackendPolicy<BACKEND> const& policy, IteratorData data_begin, IteratorData data_end, IteratorWeight weights_begin)
47+
unweight( hydra_thrust::detail::execution_policy_base<DerivedPolicy> const& policy, IteratorData data_begin, IteratorData data_end, IteratorWeight weights_begin)
4948
{
5049

5150
typedef typename IteratorWeight::value_type value_type;
@@ -68,6 +67,19 @@ unweight( detail::BackendPolicy<BACKEND> const& policy, IteratorData data_begin,
6867
return make_range(begin , r);
6968
}
7069

70+
71+
template<typename RNG, typename IteratorData, typename IteratorWeight, hydra::detail::Backend BACKEND>
72+
typename std::enable_if<
73+
detail::random::is_iterator<IteratorData>::value && detail::random::is_iterator<IteratorWeight>::value,
74+
Range<IteratorData>
75+
>::type
76+
unweight( detail::BackendPolicy<BACKEND> const& policy, IteratorData data_begin, IteratorData data_end, IteratorWeight weights_begin)
77+
{
78+
79+
80+
return unweight<RNG>(policy.backend, data_begin, data_end, weights_begin);
81+
}
82+
7183
template< typename RNG, typename IteratorData, typename IteratorWeight>
7284
typename std::enable_if<
7385
detail::random::is_iterator<IteratorData>::value && detail::random::is_iterator<IteratorWeight>::value,
@@ -116,17 +128,18 @@ unweight( IterableData&& data, IterableWeight&& weights)
116128
}
117129

118130

119-
template< typename RNG, typename Functor, typename Iterator, hydra::detail::Backend BACKEND>
131+
template< typename RNG, typename Functor, typename Iterator, typename DerivedPolicy>
120132
typename std::enable_if<
121133
detail::random::is_callable<Functor>::value && detail::random::is_iterator<Iterator>::value,
122134
Range<Iterator>
123135
>::type
124-
unweight( detail::BackendPolicy<BACKEND> const& policy, Iterator begin, Iterator end, Functor const& functor)
136+
unweight(hydra_thrust::detail::execution_policy_base<DerivedPolicy> const& policy, Iterator begin, Iterator end, Functor const& functor)
125137
{
126138

127139
typedef typename Functor::return_type value_type;
128140

129-
typedef hydra_thrust::pointer<value_type, typename detail::BackendPolicy<BACKEND>::execution_policy_type::tag_type> pointer_type;
141+
typedef hydra_thrust::pointer<value_type,
142+
typename hydra_thrust::detail::execution_policy_base<DerivedPolicy>::tag_type> pointer_type;
130143

131144
typedef detail::RndFlag<value_type,pointer_type, RNG > flagger_type;
132145

@@ -153,6 +166,19 @@ unweight( detail::BackendPolicy<BACKEND> const& policy, Iterator begin, Iterator
153166

154167
}
155168

169+
170+
template< typename RNG, typename Functor, typename Iterator, hydra::detail::Backend BACKEND>
171+
typename std::enable_if<
172+
detail::random::is_callable<Functor>::value && detail::random::is_iterator<Iterator>::value,
173+
Range<Iterator>
174+
>::type
175+
unweight( detail::BackendPolicy<BACKEND> const& policy, Iterator begin, Iterator end, Functor const& functor)
176+
{
177+
178+
return unweight<RNG>(policy.backend, begin, end, functor );
179+
180+
}
181+
156182
template<typename RNG, typename Functor, typename Iterator>
157183
typename std::enable_if<
158184
detail::random::is_callable<Functor>::value && detail::random::is_iterator<Iterator>::value,

0 commit comments

Comments
 (0)