@@ -32,7 +32,7 @@ class BatchManip : public BaseTestSuite {
3232 std::forward<F>(inner)(batch, data);
3333
3434 device->api ->freeGlobMem (data);
35- device->api ->freeGlobMem (batch);
35+ device->api ->freeUnifiedMem (batch);
3636 }
3737};
3838
@@ -166,14 +166,14 @@ TEST_F(BatchManip, uniformToScatter32) {
166166TEST_F (BatchManip, fill64) {
167167 const int N = 100 ;
168168 const int M = 120 ;
169- testWrapper<double >(N, M, false , [&](double ** batch, double * data) {
170- double scalar = 502 ;
169+ testWrapper<long >(N, M, false , [&](long ** batch, long * data) {
170+ long scalar = 502 ;
171171
172172 device->algorithms .setToValue (batch, scalar, M, N, device->api ->getDefaultStream ());
173173
174- std::vector<double > hostVector (N * M, 0 );
174+ std::vector<long > hostVector (N * M, 0 );
175175 device->api ->copyFromAsync (
176- &hostVector[0 ], data, N * M * sizeof (double ), device->api ->getDefaultStream ());
176+ &hostVector[0 ], data, N * M * sizeof (long ), device->api ->getDefaultStream ());
177177
178178 device->api ->syncDefaultStreamWithHost ();
179179
@@ -186,12 +186,12 @@ TEST_F(BatchManip, fill64) {
186186TEST_F (BatchManip, touchClean64) {
187187 const int N = 100 ;
188188 const int M = 120 ;
189- testWrapper<double >(N, M, false , [&](double ** batch, double * data) {
189+ testWrapper<long >(N, M, false , [&](long ** batch, long * data) {
190190 device->algorithms .touchBatchedMemory (batch, M, N, true , device->api ->getDefaultStream ());
191- std::vector<double > hostVector (N * M, 1 );
191+ std::vector<long > hostVector (N * M, 1 );
192192
193193 device->api ->copyFromAsync (
194- &hostVector[0 ], data, M * N * sizeof (double ), device->api ->getDefaultStream ());
194+ &hostVector[0 ], data, M * N * sizeof (long ), device->api ->getDefaultStream ());
195195
196196 device->api ->syncDefaultStreamWithHost ();
197197
@@ -200,14 +200,14 @@ TEST_F(BatchManip, touchClean64) {
200200 }
201201 });
202202
203- testWrapper<double >(N, M, true , [&](double ** batch, double * data) {
204- std::vector<double > hostVector (N * M, 1 );
203+ testWrapper<long >(N, M, true , [&](long ** batch, long * data) {
204+ std::vector<long > hostVector (N * M, 1 );
205205
206206 device->api ->copyToAsync (
207- data, &hostVector[0 ], N * M * sizeof (double ), device->api ->getDefaultStream ());
207+ data, &hostVector[0 ], N * M * sizeof (long ), device->api ->getDefaultStream ());
208208 device->algorithms .touchBatchedMemory (batch, M, N, true , device->api ->getDefaultStream ());
209209 device->api ->copyFromAsync (
210- &hostVector[0 ], data, M * N * sizeof (double ), device->api ->getDefaultStream ());
210+ &hostVector[0 ], data, M * N * sizeof (long ), device->api ->getDefaultStream ());
211211
212212 device->api ->syncDefaultStreamWithHost ();
213213
@@ -226,14 +226,14 @@ TEST_F(BatchManip, touchClean64) {
226226TEST_F (BatchManip, touchNoClean64) {
227227 const int N = 100 ;
228228 const int M = 120 ;
229- testWrapper<double >(N, M, false , [&](double ** batch, double * data) {
230- std::vector<double > hostVector (N * M, 1 );
229+ testWrapper<long >(N, M, false , [&](long ** batch, long * data) {
230+ std::vector<long > hostVector (N * M, 1 );
231231
232232 device->api ->copyToAsync (
233- data, &hostVector[0 ], N * M * sizeof (double ), device->api ->getDefaultStream ());
233+ data, &hostVector[0 ], N * M * sizeof (long ), device->api ->getDefaultStream ());
234234 device->algorithms .touchBatchedMemory (batch, M, N, false , device->api ->getDefaultStream ());
235235 device->api ->copyFromAsync (
236- &hostVector[0 ], data, N * M * sizeof (double ), device->api ->getDefaultStream ());
236+ &hostVector[0 ], data, N * M * sizeof (long ), device->api ->getDefaultStream ());
237237
238238 device->api ->syncDefaultStreamWithHost ();
239239
@@ -247,16 +247,16 @@ TEST_F(BatchManip, scatterToUniform64) {
247247 const int N = 100 ;
248248 const int M = 120 ;
249249
250- double * data2 = (double *)device->api ->allocGlobMem (N * M * sizeof (double ));
251- testWrapper<double >(N, M, false , [&](double ** batch, double * data) {
252- std::vector<double > hostVector (N * M, 1 );
250+ long * data2 = (long *)device->api ->allocGlobMem (N * M * sizeof (long ));
251+ testWrapper<long >(N, M, false , [&](long ** batch, long * data) {
252+ std::vector<long > hostVector (N * M, 1 );
253253
254254 device->api ->copyToAsync (
255- data, &hostVector[0 ], N * M * sizeof (double ), device->api ->getDefaultStream ());
255+ data, &hostVector[0 ], N * M * sizeof (long ), device->api ->getDefaultStream ());
256256 device->algorithms .copyScatterToUniform (
257- const_cast <const double **>(batch), data2, M, M, N, device->api ->getDefaultStream ());
257+ const_cast <const long **>(batch), data2, M, M, N, device->api ->getDefaultStream ());
258258 device->api ->copyFromAsync (
259- &hostVector[0 ], data2, N * M * sizeof (double ), device->api ->getDefaultStream ());
259+ &hostVector[0 ], data2, N * M * sizeof (long ), device->api ->getDefaultStream ());
260260
261261 device->api ->syncDefaultStreamWithHost ();
262262
@@ -271,15 +271,15 @@ TEST_F(BatchManip, uniformToScatter64) {
271271 const int N = 100 ;
272272 const int M = 120 ;
273273
274- double * data2 = (double *)device->api ->allocGlobMem (N * M * sizeof (double ));
275- testWrapper<double >(N, M, false , [&](double ** batch, double * data) {
276- std::vector<double > hostVector (N * M, 1 );
274+ long * data2 = (long *)device->api ->allocGlobMem (N * M * sizeof (long ));
275+ testWrapper<long >(N, M, false , [&](long ** batch, long * data) {
276+ std::vector<long > hostVector (N * M, 1 );
277277
278278 device->api ->copyToAsync (
279- data2, &hostVector[0 ], N * M * sizeof (double ), device->api ->getDefaultStream ());
279+ data2, &hostVector[0 ], N * M * sizeof (long ), device->api ->getDefaultStream ());
280280 device->algorithms .copyUniformToScatter (data2, batch, M, M, N, device->api ->getDefaultStream ());
281281 device->api ->copyFromAsync (
282- &hostVector[0 ], data, N * M * sizeof (double ), device->api ->getDefaultStream ());
282+ &hostVector[0 ], data, N * M * sizeof (long ), device->api ->getDefaultStream ());
283283
284284 device->api ->syncDefaultStreamWithHost ();
285285
0 commit comments