@@ -73,7 +73,18 @@ void run()
73
73
double total_time = 0.0 ;
74
74
75
75
namespace mkl_rng = oneapi::mkl::rng;
76
- mkl_rng::mcg59 engine (
76
+ #if USE_PHILOX
77
+ using EngineTypeHost = mkl_rng::philox4x32x10;
78
+ using EngineTypeDevice = mkl_rng::device::philox4x32x10<VEC_SIZE>;
79
+ #elif USE_MRG
80
+ using EngineTypeHost = mkl_rng::mrg32k3a;
81
+ using EngineTypeDevice = mkl_rng::device::mrg32k3a<VEC_SIZE>;
82
+ #else
83
+ using EngineTypeHost = mkl_rng::mcg59;
84
+ using EngineTypeDevice = mkl_rng::device::mcg59<VEC_SIZE>;
85
+ #endif
86
+
87
+ EngineTypeHost engine (
77
88
#if !INIT_ON_HOST
78
89
my_queue,
79
90
#else
@@ -86,18 +97,10 @@ void run()
86
97
auto rng_event_3 = mkl_rng::generate (mkl_rng::uniform<DataType>(1.0 , 5.0 ), engine, num_options, h_option_years_ptr);
87
98
88
99
std::size_t n_states = global_size;
89
- using EngineType =
90
- #if USE_PHILOX
91
- mkl_rng::device::philox4x32x10<VEC_SIZE>;
92
- #elif USE_MRG
93
- mkl_rng::device::mrg32k3a<VEC_SIZE>;
94
- #else
95
- mkl_rng::device::mcg59<VEC_SIZE>;
96
- #endif
97
100
98
101
// initialization needs only on first step
99
102
auto deleter = [my_queue](auto * ptr) {sycl::free (ptr, my_queue);};
100
- auto rng_states_uptr = std::unique_ptr<EngineType , decltype (deleter)>(sycl::malloc_device<EngineType >(n_states, my_queue), deleter);
103
+ auto rng_states_uptr = std::unique_ptr<EngineTypeDevice , decltype (deleter)>(sycl::malloc_device<EngineTypeDevice >(n_states, my_queue), deleter);
101
104
auto * rng_states = rng_states_uptr.get ();
102
105
103
106
my_queue.parallel_for <k_initialize_state<DataType>>(
@@ -107,9 +110,9 @@ void run()
107
110
auto id = idx[0 ];
108
111
#if USE_MRG
109
112
constexpr std::uint32_t seed = 12345u ;
110
- rng_states[id] = EngineType ({ seed, seed, seed, seed, seed, seed }, { 0 , (4096 * id) });
113
+ rng_states[id] = EngineTypeDevice ({ seed, seed, seed, seed, seed, seed }, { 0 , (4096 * id) });
111
114
#else
112
- rng_states[id] = EngineType (rand_seed, id * ITEMS_PER_WORK_ITEM * VEC_SIZE * block_n);
115
+ rng_states[id] = EngineTypeDevice (rand_seed, id * ITEMS_PER_WORK_ITEM * VEC_SIZE * block_n);
113
116
#endif
114
117
})
115
118
.wait_and_throw ();
0 commit comments