@@ -95,7 +95,15 @@ static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
9595// Sync arrays
9696//------------------------------------------------------------------------------
9797static int CeedVectorSyncArray_Hip (const CeedVector vec , CeedMemType mem_type ) {
98- bool need_sync = false;
98+ bool need_sync = false;
99+ CeedVector_Hip * impl ;
100+
101+ // Sync for unified memory
102+ CeedCallBackend (CeedVectorGetData (vec , & impl ));
103+ if (impl -> has_unified_addressing && !impl -> h_array_borrowed ) {
104+ CeedCallHip (CeedVectorReturnCeed (vec ), hipDeviceSynchronize ());
105+ return CEED_ERROR_SUCCESS ;
106+ }
99107
100108 // Check whether device/host sync is needed
101109 CeedCallBackend (CeedVectorNeedSync_Hip (vec , mem_type , & need_sync ));
@@ -158,6 +166,10 @@ static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, Cee
158166 CeedVector_Hip * impl ;
159167
160168 CeedCallBackend (CeedVectorGetData (vec , & impl ));
169+
170+ // Use device memory for unified memory
171+ mem_type = impl -> has_unified_addressing && !impl -> h_array_borrowed ? CEED_MEM_DEVICE : mem_type ;
172+
161173 switch (mem_type ) {
162174 case CEED_MEM_HOST :
163175 * has_borrowed_array_of_type = impl -> h_array_borrowed ;
@@ -202,6 +214,43 @@ static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode
202214 return CEED_ERROR_SUCCESS ;
203215}
204216
217+ //------------------------------------------------------------------------------
218+ // Set array with unified memory
219+ //------------------------------------------------------------------------------
220+ static int CeedVectorSetArrayUnifiedHostToDevice_Hip (const CeedVector vec , const CeedCopyMode copy_mode , CeedScalar * array ) {
221+ CeedSize length ;
222+ Ceed ceed ;
223+ CeedVector_Hip * impl ;
224+
225+ CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
226+ CeedCallBackend (CeedVectorGetData (vec , & impl ));
227+ CeedCallBackend (CeedVectorGetLength (vec , & length ));
228+
229+ switch (copy_mode ) {
230+ case CEED_COPY_VALUES :
231+ case CEED_OWN_POINTER :
232+ if (!impl -> d_array ) {
233+ if (impl -> d_array_borrowed ) {
234+ impl -> d_array = impl -> d_array_borrowed ;
235+ } else {
236+ if (!impl -> d_array_owned ) CeedCallHip (ceed , hipMalloc ((void * * )& impl -> d_array_owned , sizeof (CeedScalar ) * length ));
237+ impl -> d_array = impl -> d_array_owned ;
238+ }
239+ }
240+ if (array ) CeedCallHip (ceed , hipMemcpy (impl -> d_array , array , sizeof (CeedScalar ) * length , hipMemcpyHostToDevice ));
241+ if (copy_mode == CEED_OWN_POINTER ) CeedCallBackend (CeedFree (& array ));
242+ break ;
243+ case CEED_USE_POINTER :
244+ CeedCallHip (ceed , hipFree (impl -> d_array_owned ));
245+ CeedCallBackend (CeedFree (& impl -> h_array_owned ));
246+ impl -> h_array_owned = NULL ;
247+ impl -> h_array_borrowed = array ;
248+ impl -> d_array = impl -> h_array_borrowed ;
249+ }
250+ CeedCallBackend (CeedDestroy (& ceed ));
251+ return CEED_ERROR_SUCCESS ;
252+ }
253+
205254//------------------------------------------------------------------------------
206255// Set the array used by a vector,
207256// freeing any previously allocated array if applicable
@@ -213,7 +262,11 @@ static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
213262 CeedCallBackend (CeedVectorSetAllInvalid_Hip (vec ));
214263 switch (mem_type ) {
215264 case CEED_MEM_HOST :
216- return CeedVectorSetArrayHost_Hip (vec , copy_mode , array );
265+ if (impl -> has_unified_addressing ) {
266+ return CeedVectorSetArrayUnifiedHostToDevice_Hip (vec , copy_mode , array );
267+ } else {
268+ return CeedVectorSetArrayHost_Hip (vec , copy_mode , array );
269+ }
217270 case CEED_MEM_DEVICE :
218271 return CeedVectorSetArrayDevice_Hip (vec , copy_mode , array );
219272 }
@@ -303,8 +356,10 @@ int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val)
303356static int CeedVectorSetValue_Hip (CeedVector vec , CeedScalar val ) {
304357 CeedSize length ;
305358 CeedVector_Hip * impl ;
359+ Ceed_Hip * hip_data ;
306360
307361 CeedCallBackend (CeedVectorGetData (vec , & impl ));
362+ CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
308363 CeedCallBackend (CeedVectorGetLength (vec , & length ));
309364 // Set value for synced device/host array
310365 if (!impl -> d_array && !impl -> h_array ) {
@@ -321,7 +376,7 @@ static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
321376 }
322377 }
323378 if (impl -> d_array ) {
324- if (val == 0 ) {
379+ if (val == 0 && ! impl -> h_array_borrowed ) {
325380 CeedCallHip (CeedVectorReturnCeed (vec ), hipMemset (impl -> d_array , 0 , length * sizeof (CeedScalar )));
326381 } else {
327382 CeedCallBackend (CeedDeviceSetValue_Hip (impl -> d_array , length , val ));
@@ -398,14 +453,17 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedSca
398453}
399454
400455//------------------------------------------------------------------------------
401- // Core logic for array syncronization for GetArray.
456+ // Core logic for array synchronization for GetArray.
402457// If a different memory type is most up to date, this will perform a copy
403458//------------------------------------------------------------------------------
404- static int CeedVectorGetArrayCore_Hip (const CeedVector vec , const CeedMemType mem_type , CeedScalar * * array ) {
459+ static int CeedVectorGetArrayCore_Hip (const CeedVector vec , CeedMemType mem_type , CeedScalar * * array ) {
405460 CeedVector_Hip * impl ;
406461
407462 CeedCallBackend (CeedVectorGetData (vec , & impl ));
408463
464+ // Use device memory for unified memory
465+ mem_type = impl -> has_unified_addressing && !impl -> h_array_borrowed ? CEED_MEM_DEVICE : mem_type ;
466+
409467 // Sync array to requested mem_type
410468 CeedCallBackend (CeedVectorSyncArray (vec , mem_type ));
411469
@@ -431,15 +489,21 @@ static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType me
431489//------------------------------------------------------------------------------
432490// Get read/write access to a vector via the specified mem_type
433491//------------------------------------------------------------------------------
434- static int CeedVectorGetArray_Hip (const CeedVector vec , const CeedMemType mem_type , CeedScalar * * array ) {
492+ static int CeedVectorGetArray_Hip (const CeedVector vec , CeedMemType mem_type , CeedScalar * * array ) {
435493 CeedVector_Hip * impl ;
436494
437495 CeedCallBackend (CeedVectorGetData (vec , & impl ));
496+
497+ // Use device memory for unified memory
498+ mem_type = impl -> has_unified_addressing && !impl -> h_array_borrowed ? CEED_MEM_DEVICE : mem_type ;
499+
500+ // 'Get' array and set only 'get'ed array as valid
438501 CeedCallBackend (CeedVectorGetArrayCore_Hip (vec , mem_type , array ));
439502 CeedCallBackend (CeedVectorSetAllInvalid_Hip (vec ));
440503 switch (mem_type ) {
441504 case CEED_MEM_HOST :
442505 impl -> h_array = * array ;
506+ if (impl -> has_unified_addressing ) impl -> d_array = * array ;
443507 break ;
444508 case CEED_MEM_DEVICE :
445509 impl -> d_array = * array ;
@@ -451,11 +515,17 @@ static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
451515//------------------------------------------------------------------------------
452516// Get write access to a vector via the specified mem_type
453517//------------------------------------------------------------------------------
454- static int CeedVectorGetArrayWrite_Hip (const CeedVector vec , const CeedMemType mem_type , CeedScalar * * array ) {
518+ static int CeedVectorGetArrayWrite_Hip (const CeedVector vec , CeedMemType mem_type , CeedScalar * * array ) {
455519 bool has_array_of_type = true;
456520 CeedVector_Hip * impl ;
521+ Ceed_Hip * hip_data ;
457522
458523 CeedCallBackend (CeedVectorGetData (vec , & impl ));
524+ CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
525+
526+ // Use device memory for unified memory
527+ mem_type = impl -> has_unified_addressing && !impl -> h_array_borrowed ? CEED_MEM_DEVICE : mem_type ;
528+
459529 CeedCallBackend (CeedVectorHasArrayOfType_Hip (vec , mem_type , & has_array_of_type ));
460530 if (!has_array_of_type ) {
461531 // Allocate if array is not yet allocated
@@ -487,8 +557,10 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
487557 const CeedScalar * d_array ;
488558 CeedVector_Hip * impl ;
489559 hipblasHandle_t handle ;
560+ Ceed_Hip * hip_data ;
490561
491562 CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
563+ CeedCallBackend (CeedGetData (ceed , & hip_data ));
492564 CeedCallBackend (CeedVectorGetData (vec , & impl ));
493565 CeedCallBackend (CeedVectorGetLength (vec , & length ));
494566 CeedCallBackend (CeedGetHipblasHandle_Hip (ceed , & handle ));
@@ -518,7 +590,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
518590 CeedSize remaining_length = length - (CeedSize )(i )* INT_MAX ;
519591 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
520592
521- CeedCallHipblas (ceed , cublasSasum (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & sub_norm ));
593+ CeedCallHipblas (ceed , hipblasSasum (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & sub_norm ));
522594 * norm += sub_norm ;
523595 }
524596#endif /* HIP_VERSION */
@@ -545,7 +617,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
545617#if defined(CEED_SCALAR_IS_FP32 )
546618#if (HIP_VERSION >= 60000000 )
547619 CeedCallHipblas (ceed , hipblasSnrm2_64 (handle , (int64_t )length , (float * )d_array , 1 , (float * )norm ));
548- #else /* CUDA_VERSION */
620+ #else /* HIP_VERSION */
549621 float sub_norm = 0.0 , norm_sum = 0.0 ;
550622 float * d_array_start ;
551623
@@ -562,7 +634,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
562634#else /* CEED_SCALAR */
563635#if (HIP_VERSION >= 60000000 )
564636 CeedCallHipblas (ceed , hipblasDnrm2_64 (handle , (int64_t )length , (double * )d_array , 1 , (double * )norm ));
565- #else /* CUDA_VERSION */
637+ #else /* HIP_VERSION */
566638 double sub_norm = 0.0 , norm_sum = 0.0 ;
567639 double * d_array_start ;
568640
@@ -599,7 +671,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
599671 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
600672
601673 CeedCallHipblas (ceed , hipblasIsamax (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & index ));
602- CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
674+ if (hip_data -> has_unified_addressing ) {
675+ CeedCallHip (ceed , hipDeviceSynchronize ());
676+ sub_max = fabs (d_array [index - 1 ]);
677+ } else {
678+ CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
679+ }
603680 if (fabs (sub_max ) > current_max ) current_max = fabs (sub_max );
604681 }
605682 * norm = current_max ;
@@ -610,7 +687,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
610687 CeedScalar norm_no_abs ;
611688
612689 CeedCallHipblas (ceed , hipblasIdamax_64 (handle , (int64_t )length , (double * )d_array , 1 , & index ));
613- CeedCallHip (ceed , hipMemcpy (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
690+ if (hip_data -> has_unified_addressing ) {
691+ CeedCallHip (ceed , hipDeviceSynchronize ());
692+ norm_no_abs = fabs (d_array [index - 1 ]);
693+ } else {
694+ CeedCallHip (ceed , hipMemcpy (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
695+ }
614696 * norm = fabs (norm_no_abs );
615697#else /* HIP_VERSION */
616698 CeedInt index ;
@@ -623,7 +705,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
623705 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
624706
625707 CeedCallHipblas (ceed , hipblasIdamax (handle , (CeedInt )sub_length , (double * )d_array_start , 1 , & index ));
626- CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
708+ if (hip_data -> has_unified_addressing ) {
709+ CeedCallHip (ceed , hipDeviceSynchronize ());
710+ sub_max = fabs (d_array [index - 1 ]);
711+ } else {
712+ CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
713+ }
627714 if (fabs (sub_max ) > current_max ) current_max = fabs (sub_max );
628715 }
629716 * norm = current_max ;
@@ -854,6 +941,7 @@ static int CeedVectorDestroy_Hip(const CeedVector vec) {
854941//------------------------------------------------------------------------------
855942int CeedVectorCreate_Hip (CeedSize n , CeedVector vec ) {
856943 CeedVector_Hip * impl ;
944+ Ceed_Hip * hip_impl ;
857945 Ceed ceed ;
858946
859947 CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
@@ -875,8 +963,10 @@ int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
875963 CeedCallBackend (CeedSetBackendFunction (ceed , "Vector" , vec , "AXPBY" , CeedVectorAXPBY_Hip ));
876964 CeedCallBackend (CeedSetBackendFunction (ceed , "Vector" , vec , "PointwiseMult" , CeedVectorPointwiseMult_Hip ));
877965 CeedCallBackend (CeedSetBackendFunction (ceed , "Vector" , vec , "Destroy" , CeedVectorDestroy_Hip ));
878- CeedCallBackend (CeedDestroy (& ceed ));
879966 CeedCallBackend (CeedCalloc (1 , & impl ));
967+ CeedCallBackend (CeedGetData (ceed , & hip_impl ));
968+ CeedCallBackend (CeedDestroy (& ceed ));
969+ impl -> has_unified_addressing = hip_impl -> has_unified_addressing ;
880970 CeedCallBackend (CeedVectorSetData (vec , impl ));
881971 return CEED_ERROR_SUCCESS ;
882972}
0 commit comments