@@ -95,7 +95,21 @@ static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
9595// Sync arrays
9696//------------------------------------------------------------------------------
9797static int CeedVectorSyncArray_Hip (const CeedVector vec , CeedMemType mem_type ) {
98- bool need_sync = false;
98+ bool need_sync = false;
99+ Ceed_Hip * hip_data ;
100+
101+ // Sync for unified memory
102+ CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
103+ if (hip_data -> has_unified_addressing ) {
104+ Ceed ceed ;
105+ CeedVector_Hip * impl ;
106+
107+ CeedCallBackend (CeedVectorGetData (vec , & impl ));
108+ CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
109+ CeedCallHip (ceed , hipDeviceSynchronize ());
110+ CeedCallBackend (CeedDestroy (& ceed ));
111+ return CEED_ERROR_SUCCESS ;
112+ }
99113
100114 // Check whether device/host sync is needed
101115 CeedCallBackend (CeedVectorNeedSync_Hip (vec , mem_type , & need_sync ));
@@ -156,8 +170,14 @@ static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType
156170//------------------------------------------------------------------------------
157171static inline int CeedVectorHasBorrowedArrayOfType_Hip (const CeedVector vec , CeedMemType mem_type , bool * has_borrowed_array_of_type ) {
158172 CeedVector_Hip * impl ;
173+ Ceed_Hip * hip_data ;
159174
160175 CeedCallBackend (CeedVectorGetData (vec , & impl ));
176+ CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
177+
178+ // Use device memory for unified memory
179+ mem_type = hip_data -> has_unified_addressing ? CEED_MEM_DEVICE : mem_type ;
180+
161181 switch (mem_type ) {
162182 case CEED_MEM_HOST :
163183 * has_borrowed_array_of_type = impl -> h_array_borrowed ;
@@ -303,8 +323,10 @@ int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val)
303323static int CeedVectorSetValue_Hip (CeedVector vec , CeedScalar val ) {
304324 CeedSize length ;
305325 CeedVector_Hip * impl ;
326+ Ceed_Hip * hip_data ;
306327
307328 CeedCallBackend (CeedVectorGetData (vec , & impl ));
329+ CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
308330 CeedCallBackend (CeedVectorGetLength (vec , & length ));
309331 // Set value for synced device/host array
310332 if (!impl -> d_array && !impl -> h_array ) {
@@ -403,8 +425,13 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedSca
403425//------------------------------------------------------------------------------
404426static int CeedVectorGetArrayCore_Hip (const CeedVector vec , const CeedMemType mem_type , CeedScalar * * array ) {
405427 CeedVector_Hip * impl ;
428+ Ceed_Hip * hip_data ;
406429
407430 CeedCallBackend (CeedVectorGetData (vec , & impl ));
431+ CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
432+
433+ // Use device memory for unified memory
434+ mem_type = hip_data -> has_unified_addressing && impl -> d_array ? CEED_MEM_DEVICE : mem_type ;
408435
409436 // Sync array to requested mem_type
410437 CeedCallBackend (CeedVectorSyncArray (vec , mem_type ));
@@ -433,8 +460,15 @@ static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType me
433460//------------------------------------------------------------------------------
434461static int CeedVectorGetArray_Hip (const CeedVector vec , const CeedMemType mem_type , CeedScalar * * array ) {
435462 CeedVector_Hip * impl ;
463+ Ceed_Hip * hip_data ;
436464
437465 CeedCallBackend (CeedVectorGetData (vec , & impl ));
466+ CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
467+
468+ // Use device memory for unified memory
469+ mem_type = hip_data -> has_unified_addressing && impl -> d_array ? CEED_MEM_DEVICE : mem_type ;
470+
471+ // 'Get' array and set only 'get'ed array as valid
438472 CeedCallBackend (CeedVectorGetArrayCore_Hip (vec , mem_type , array ));
439473 CeedCallBackend (CeedVectorSetAllInvalid_Hip (vec ));
440474 switch (mem_type ) {
@@ -454,8 +488,14 @@ static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
454488static int CeedVectorGetArrayWrite_Hip (const CeedVector vec , const CeedMemType mem_type , CeedScalar * * array ) {
455489 bool has_array_of_type = true;
456490 CeedVector_Hip * impl ;
491+ Ceed_Hip * hip_data ;
457492
458493 CeedCallBackend (CeedVectorGetData (vec , & impl ));
494+ CeedCallBackend (CeedGetData (CeedVectorReturnCeed (vec ), & hip_data ));
495+
496+ // Use device memory for unified memory
497+ mem_type = hip_data -> has_unified_addressing && impl -> d_array ? CEED_MEM_DEVICE : mem_type ;
498+
459499 CeedCallBackend (CeedVectorHasArrayOfType_Hip (vec , mem_type , & has_array_of_type ));
460500 if (!has_array_of_type ) {
461501 // Allocate if array is not yet allocated
@@ -487,8 +527,10 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
487527 const CeedScalar * d_array ;
488528 CeedVector_Hip * impl ;
489529 hipblasHandle_t handle ;
530+ Ceed_Hip * hip_data ;
490531
491532 CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
533+ CeedCallBackend (CeedGetData (ceed , & hip_data ));
492534 CeedCallBackend (CeedVectorGetData (vec , & impl ));
493535 CeedCallBackend (CeedVectorGetLength (vec , & length ));
494536 CeedCallBackend (CeedGetHipblasHandle_Hip (ceed , & handle ));
@@ -518,7 +560,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
518560 CeedSize remaining_length = length - (CeedSize )(i )* INT_MAX ;
519561 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
520562
521- CeedCallHipblas (ceed , cublasSasum (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & sub_norm ));
563+ CeedCallHipblas (ceed , hipblasSasum (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & sub_norm ));
522564 * norm += sub_norm ;
523565 }
524566#endif /* HIP_VERSION */
@@ -545,7 +587,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
545587#if defined(CEED_SCALAR_IS_FP32 )
546588#if (HIP_VERSION >= 60000000 )
547589 CeedCallHipblas (ceed , hipblasSnrm2_64 (handle , (int64_t )length , (float * )d_array , 1 , (float * )norm ));
548- #else /* CUDA_VERSION */
590+ #else /* HIP_VERSION */
549591 float sub_norm = 0.0 , norm_sum = 0.0 ;
550592 float * d_array_start ;
551593
@@ -562,7 +604,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
562604#else /* CEED_SCALAR */
563605#if (HIP_VERSION >= 60000000 )
564606 CeedCallHipblas (ceed , hipblasDnrm2_64 (handle , (int64_t )length , (double * )d_array , 1 , (double * )norm ));
565- #else /* CUDA_VERSION */
607+ #else /* HIP_VERSION */
566608 double sub_norm = 0.0 , norm_sum = 0.0 ;
567609 double * d_array_start ;
568610
@@ -599,7 +641,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
599641 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
600642
601643 CeedCallHipblas (ceed , hipblasIsamax (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & index ));
602- CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
644+ if ((hip_data -> has_unified_addressing ) {
645+ CeedCallHip (ceed , hipDeviceSynchronize ());
646+ sub_max = fabs (d_array [index - 1 ]);
647+ } else {
648+ CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
649+ }
603650 if (fabs (sub_max ) > current_max ) current_max = fabs (sub_max );
604651 }
605652 * norm = current_max ;
@@ -610,7 +657,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
610657 CeedScalar norm_no_abs ;
611658
612659 CeedCallHipblas (ceed , hipblasIdamax_64 (handle , (int64_t )length , (double * )d_array , 1 , & index ));
613- CeedCallHip (ceed , hipMemcpy (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
660+ if ((hip_data -> has_unified_addressing ) {
661+ CeedCallHip (ceed , hipDeviceSynchronize ());
662+ norm_no_abs = fabs (d_array [index - 1 ]);
663+ } else {
664+ CeedCallHip (ceed , hipMemcpy (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
665+ }
614666 * norm = fabs (norm_no_abs );
615667#else /* HIP_VERSION */
616668 CeedInt index ;
@@ -623,7 +675,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
623675 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
624676
625677 CeedCallHipblas (ceed , hipblasIdamax (handle , (CeedInt )sub_length , (double * )d_array_start , 1 , & index ));
626- CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
678+ if ((hip_data -> has_unified_addressing ) {
679+ CeedCallHip (ceed , hipDeviceSynchronize ());
680+ sub_max = fabs (d_array [index - 1 ]);
681+ } else {
682+ CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
683+ }
627684 if (fabs (sub_max ) > current_max ) current_max = fabs (sub_max );
628685 }
629686 * norm = current_max ;
0 commit comments