@@ -95,7 +95,15 @@ static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
9595// Sync arrays
9696//------------------------------------------------------------------------------
9797static int CeedVectorSyncArray_Hip (const CeedVector vec , CeedMemType mem_type ) {
98- bool need_sync = false;
98+ bool need_sync = false;
99+ CeedVector_Hip * impl ;
100+
101+ // Sync for unified memory
102+ CeedCallBackend (CeedVectorGetData (vec , & impl ));
103+ if (impl -> has_unified_addressing && !impl -> h_array_borrowed ) {
104+ CeedCallHip (CeedVectorReturnCeed (vec ), hipDeviceSynchronize ());
105+ return CEED_ERROR_SUCCESS ;
106+ }
99107
100108 // Check whether device/host sync is needed
101109 CeedCallBackend (CeedVectorNeedSync_Hip (vec , mem_type , & need_sync ));
@@ -158,6 +166,10 @@ static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, Cee
158166 CeedVector_Hip * impl ;
159167
160168 CeedCallBackend (CeedVectorGetData (vec , & impl ));
169+
170+ // Use device memory for unified memory
171+ mem_type = impl -> has_unified_addressing && !impl -> h_array_borrowed ? CEED_MEM_DEVICE : mem_type ;
172+
161173 switch (mem_type ) {
162174 case CEED_MEM_HOST :
163175 * has_borrowed_array_of_type = impl -> h_array_borrowed ;
@@ -202,6 +214,43 @@ static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode
202214 return CEED_ERROR_SUCCESS ;
203215}
204216
217+ //------------------------------------------------------------------------------
218+ // Set array with unified memory
219+ //------------------------------------------------------------------------------
220+ static int CeedVectorSetArrayUnifiedHostToDevice_Hip (const CeedVector vec , const CeedCopyMode copy_mode , CeedScalar * array ) {
221+ CeedSize length ;
222+ Ceed ceed ;
223+ CeedVector_Hip * impl ;
224+
225+ CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
226+ CeedCallBackend (CeedVectorGetData (vec , & impl ));
227+ CeedCallBackend (CeedVectorGetLength (vec , & length ));
228+
229+ switch (copy_mode ) {
230+ case CEED_COPY_VALUES :
231+ case CEED_OWN_POINTER :
232+ if (!impl -> d_array ) {
233+ if (impl -> d_array_borrowed ) {
234+ impl -> d_array = impl -> d_array_borrowed ;
235+ } else {
236+ if (!impl -> d_array_owned ) CeedCallHip (ceed , hipMalloc ((void * * )& impl -> d_array_owned , sizeof (CeedScalar ) * length ));
237+ impl -> d_array = impl -> d_array_owned ;
238+ }
239+ }
240+ if (array ) CeedCallHip (ceed , hipMemcpy (impl -> d_array , array , sizeof (CeedScalar ) * length , hipMemcpyHostToDevice ));
241+ if (copy_mode == CEED_OWN_POINTER ) CeedCallBackend (CeedFree (& array ));
242+ break ;
243+ case CEED_USE_POINTER :
244+ CeedCallHip (ceed , hipFree (impl -> d_array_owned ));
245+ CeedCallBackend (CeedFree (& impl -> h_array_owned ));
246+ impl -> h_array_owned = NULL ;
247+ impl -> h_array_borrowed = array ;
248+ impl -> d_array = impl -> h_array_borrowed ;
249+ }
250+ CeedCallBackend (CeedDestroy (& ceed ));
251+ return CEED_ERROR_SUCCESS ;
252+ }
253+
205254//------------------------------------------------------------------------------
206255// Set the array used by a vector,
207256// freeing any previously allocated array if applicable
@@ -213,7 +262,11 @@ static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
213262 CeedCallBackend (CeedVectorSetAllInvalid_Hip (vec ));
214263 switch (mem_type ) {
215264 case CEED_MEM_HOST :
216- return CeedVectorSetArrayHost_Hip (vec , copy_mode , array );
265+ if (impl -> has_unified_addressing ) {
266+ return CeedVectorSetArrayUnifiedHostToDevice_Hip (vec , copy_mode , array );
267+ } else {
268+ return CeedVectorSetArrayHost_Hip (vec , copy_mode , array );
269+ }
217270 case CEED_MEM_DEVICE :
218271 return CeedVectorSetArrayDevice_Hip (vec , copy_mode , array );
219272 }
@@ -303,8 +356,10 @@ int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val)
303356static int CeedVectorSetValue_Hip (CeedVector vec , CeedScalar val ) {
304357 CeedSize length ;
305358 CeedVector_Hip * impl ;
359+ Ceed_Hip * hip_data ;
306360
307361 CeedCallBackend (CeedVectorGetData (vec , & impl ));
362+ CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
308363 CeedCallBackend (CeedVectorGetLength (vec , & length ));
309364 // Set value for synced device/host array
310365 if (!impl -> d_array && !impl -> h_array ) {
@@ -321,7 +376,7 @@ static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
321376 }
322377 }
323378 if (impl -> d_array ) {
324- if (val == 0 ) {
379+ if (val == 0 && ! impl -> h_array_borrowed ) {
325380 CeedCallHip (CeedVectorReturnCeed (vec ), hipMemset (impl -> d_array , 0 , length * sizeof (CeedScalar )));
326381 } else {
327382 CeedCallBackend (CeedDeviceSetValue_Hip (impl -> d_array , length , val ));
@@ -398,14 +453,17 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedSca
398453}
399454
400455//------------------------------------------------------------------------------
401- // Core logic for array syncronization for GetArray.
456+ // Core logic for array synchronization for GetArray.
402457// If a different memory type is most up to date, this will perform a copy
403458//------------------------------------------------------------------------------
404- static int CeedVectorGetArrayCore_Hip (const CeedVector vec , const CeedMemType mem_type , CeedScalar * * array ) {
459+ static int CeedVectorGetArrayCore_Hip (const CeedVector vec , CeedMemType mem_type , CeedScalar * * array ) {
405460 CeedVector_Hip * impl ;
406461
407462 CeedCallBackend (CeedVectorGetData (vec , & impl ));
408463
464+ // Use device memory for unified memory
465+ mem_type = impl -> has_unified_addressing && !impl -> h_array_borrowed ? CEED_MEM_DEVICE : mem_type ;
466+
409467 // Sync array to requested mem_type
410468 CeedCallBackend (CeedVectorSyncArray (vec , mem_type ));
411469
@@ -431,10 +489,15 @@ static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType me
431489//------------------------------------------------------------------------------
432490// Get read/write access to a vector via the specified mem_type
433491//------------------------------------------------------------------------------
434- static int CeedVectorGetArray_Hip (const CeedVector vec , const CeedMemType mem_type , CeedScalar * * array ) {
492+ static int CeedVectorGetArray_Hip (const CeedVector vec , CeedMemType mem_type , CeedScalar * * array ) {
435493 CeedVector_Hip * impl ;
436494
437495 CeedCallBackend (CeedVectorGetData (vec , & impl ));
496+
497+ // Use device memory for unified memory
498+ mem_type = impl -> has_unified_addressing && !impl -> h_array_borrowed ? CEED_MEM_DEVICE : mem_type ;
499+
500+ // 'Get' array and set only 'get'ed array as valid
438501 CeedCallBackend (CeedVectorGetArrayCore_Hip (vec , mem_type , array ));
439502 CeedCallBackend (CeedVectorSetAllInvalid_Hip (vec ));
440503 switch (mem_type ) {
@@ -451,11 +514,17 @@ static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
451514//------------------------------------------------------------------------------
452515// Get write access to a vector via the specified mem_type
453516//------------------------------------------------------------------------------
454- static int CeedVectorGetArrayWrite_Hip (const CeedVector vec , const CeedMemType mem_type , CeedScalar * * array ) {
517+ static int CeedVectorGetArrayWrite_Hip (const CeedVector vec , CeedMemType mem_type , CeedScalar * * array ) {
455518 bool has_array_of_type = true;
456519 CeedVector_Hip * impl ;
520+ Ceed_Hip * hip_data ;
457521
458522 CeedCallBackend (CeedVectorGetData (vec , & impl ));
523+ CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
524+
525+ // Use device memory for unified memory
526+ mem_type = impl -> has_unified_addressing && !impl -> h_array_borrowed ? CEED_MEM_DEVICE : mem_type ;
527+
459528 CeedCallBackend (CeedVectorHasArrayOfType_Hip (vec , mem_type , & has_array_of_type ));
460529 if (!has_array_of_type ) {
461530 // Allocate if array is not yet allocated
@@ -487,8 +556,10 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
487556 const CeedScalar * d_array ;
488557 CeedVector_Hip * impl ;
489558 hipblasHandle_t handle ;
559+ Ceed_Hip * hip_data ;
490560
491561 CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
562+ CeedCallBackend (CeedGetData (ceed , & hip_data ));
492563 CeedCallBackend (CeedVectorGetData (vec , & impl ));
493564 CeedCallBackend (CeedVectorGetLength (vec , & length ));
494565 CeedCallBackend (CeedGetHipblasHandle_Hip (ceed , & handle ));
@@ -518,7 +589,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
518589 CeedSize remaining_length = length - (CeedSize )(i )* INT_MAX ;
519590 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
520591
521- CeedCallHipblas (ceed , cublasSasum (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & sub_norm ));
592+ CeedCallHipblas (ceed , hipblasSasum (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & sub_norm ));
522593 * norm += sub_norm ;
523594 }
524595#endif /* HIP_VERSION */
@@ -545,7 +616,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
545616#if defined(CEED_SCALAR_IS_FP32 )
546617#if (HIP_VERSION >= 60000000 )
547618 CeedCallHipblas (ceed , hipblasSnrm2_64 (handle , (int64_t )length , (float * )d_array , 1 , (float * )norm ));
548- #else /* CUDA_VERSION */
619+ #else /* HIP_VERSION */
549620 float sub_norm = 0.0 , norm_sum = 0.0 ;
550621 float * d_array_start ;
551622
@@ -562,7 +633,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
562633#else /* CEED_SCALAR */
563634#if (HIP_VERSION >= 60000000 )
564635 CeedCallHipblas (ceed , hipblasDnrm2_64 (handle , (int64_t )length , (double * )d_array , 1 , (double * )norm ));
565- #else /* CUDA_VERSION */
636+ #else /* HIP_VERSION */
566637 double sub_norm = 0.0 , norm_sum = 0.0 ;
567638 double * d_array_start ;
568639
@@ -599,7 +670,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
599670 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
600671
601672 CeedCallHipblas (ceed , hipblasIsamax (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & index ));
602- CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
673+ if (hip_data -> has_unified_addressing ) {
674+ CeedCallHip (ceed , hipDeviceSynchronize ());
675+ sub_max = fabs (d_array [index - 1 ]);
676+ } else {
677+ CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
678+ }
603679 if (fabs (sub_max ) > current_max ) current_max = fabs (sub_max );
604680 }
605681 * norm = current_max ;
@@ -610,7 +686,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
610686 CeedScalar norm_no_abs ;
611687
612688 CeedCallHipblas (ceed , hipblasIdamax_64 (handle , (int64_t )length , (double * )d_array , 1 , & index ));
613- CeedCallHip (ceed , hipMemcpy (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
689+ if (hip_data -> has_unified_addressing ) {
690+ CeedCallHip (ceed , hipDeviceSynchronize ());
691+ norm_no_abs = fabs (d_array [index - 1 ]);
692+ } else {
693+ CeedCallHip (ceed , hipMemcpy (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
694+ }
614695 * norm = fabs (norm_no_abs );
615696#else /* HIP_VERSION */
616697 CeedInt index ;
@@ -623,7 +704,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
623704 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
624705
625706 CeedCallHipblas (ceed , hipblasIdamax (handle , (CeedInt )sub_length , (double * )d_array_start , 1 , & index ));
626- CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
707+ if (hip_data -> has_unified_addressing ) {
708+ CeedCallHip (ceed , hipDeviceSynchronize ());
709+ sub_max = fabs (d_array [index - 1 ]);
710+ } else {
711+ CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
712+ }
627713 if (fabs (sub_max ) > current_max ) current_max = fabs (sub_max );
628714 }
629715 * norm = current_max ;
@@ -854,6 +940,7 @@ static int CeedVectorDestroy_Hip(const CeedVector vec) {
854940//------------------------------------------------------------------------------
855941int CeedVectorCreate_Hip (CeedSize n , CeedVector vec ) {
856942 CeedVector_Hip * impl ;
943+ Ceed_Hip * hip_impl ;
857944 Ceed ceed ;
858945
859946 CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
@@ -875,8 +962,10 @@ int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
875962 CeedCallBackend (CeedSetBackendFunction (ceed , "Vector" , vec , "AXPBY" , CeedVectorAXPBY_Hip ));
876963 CeedCallBackend (CeedSetBackendFunction (ceed , "Vector" , vec , "PointwiseMult" , CeedVectorPointwiseMult_Hip ));
877964 CeedCallBackend (CeedSetBackendFunction (ceed , "Vector" , vec , "Destroy" , CeedVectorDestroy_Hip ));
878- CeedCallBackend (CeedDestroy (& ceed ));
879965 CeedCallBackend (CeedCalloc (1 , & impl ));
966+ CeedCallBackend (CeedGetData (ceed , & hip_impl ));
967+ CeedCallBackend (CeedDestroy (& ceed ));
968+ impl -> has_unified_addressing = hip_impl -> has_unified_addressing ;
880969 CeedCallBackend (CeedVectorSetData (vec , impl ));
881970 return CEED_ERROR_SUCCESS ;
882971}
0 commit comments