Skip to content

Commit b314491

Browse files
committed
hip - use unified memory
Co-authored-by: Zach Atkins <zach.atkins@colorado.edu>
1 parent a637ca9 commit b314491

File tree

4 files changed

+113
-15
lines changed

4 files changed

+113
-15
lines changed

backends/hip-ref/ceed-hip-ref-vector.c

Lines changed: 103 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,15 @@ static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
9595
// Sync arrays
9696
//------------------------------------------------------------------------------
9797
static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
98-
bool need_sync = false;
98+
bool need_sync = false;
99+
CeedVector_Hip *impl;
100+
101+
// Sync for unified memory
102+
CeedCallBackend(CeedVectorGetData(vec, &impl));
103+
if (impl->has_unified_addressing && !impl->h_array_borrowed) {
104+
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
105+
return CEED_ERROR_SUCCESS;
106+
}
99107

100108
// Check whether device/host sync is needed
101109
CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
@@ -158,6 +166,10 @@ static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, Cee
158166
CeedVector_Hip *impl;
159167

160168
CeedCallBackend(CeedVectorGetData(vec, &impl));
169+
170+
// Use device memory for unified memory
171+
mem_type = impl->has_unified_addressing && !impl->h_array_borrowed ? CEED_MEM_DEVICE : mem_type;
172+
161173
switch (mem_type) {
162174
case CEED_MEM_HOST:
163175
*has_borrowed_array_of_type = impl->h_array_borrowed;
@@ -202,6 +214,43 @@ static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode
202214
return CEED_ERROR_SUCCESS;
203215
}
204216

217+
//------------------------------------------------------------------------------
218+
// Set array with unified memory
219+
//------------------------------------------------------------------------------
220+
static int CeedVectorSetArrayUnifiedHostToDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
  // Set a vector's array from a host pointer when unified addressing is in use.
  //
  //   CEED_COPY_VALUES / CEED_OWN_POINTER:
  //     make sure impl->d_array points at a device array (the borrowed one if present,
  //     otherwise the owned one, allocating it on first use), copy `array` into it when
  //     `array` is non-NULL, and free `array` when ownership was transferred to us.
  //   CEED_USE_POINTER:
  //     release the owned device and host arrays and record `array` as the borrowed
  //     host array, which is then also used directly as impl->d_array.
  //
  // Returns CEED_ERROR_SUCCESS on success; failures are propagated by the CeedCall* macros.
  CeedSize        length;
  Ceed            ceed;
  CeedVector_Hip *impl;

  CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
  CeedCallBackend(CeedVectorGetData(vec, &impl));
  CeedCallBackend(CeedVectorGetLength(vec, &length));

  switch (copy_mode) {
    case CEED_COPY_VALUES:
    case CEED_OWN_POINTER:
      if (!impl->d_array) {
        if (impl->d_array_borrowed) {
          impl->d_array = impl->d_array_borrowed;
        } else {
          if (!impl->d_array_owned) CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, sizeof(CeedScalar) * length));
          impl->d_array = impl->d_array_owned;
        }
      }
      if (array) CeedCallHip(ceed, hipMemcpy(impl->d_array, array, sizeof(CeedScalar) * length, hipMemcpyHostToDevice));
      if (copy_mode == CEED_OWN_POINTER) CeedCallBackend(CeedFree(&array));
      break;
    case CEED_USE_POINTER:
      CeedCallHip(ceed, hipFree(impl->d_array_owned));
      // Clear the owned device pointer after freeing it: leaving it set would let a later
      // COPY_VALUES/OWN_POINTER call reuse freed memory and a later hipFree free it twice
      impl->d_array_owned = NULL;
      CeedCallBackend(CeedFree(&impl->h_array_owned));
      impl->h_array_owned    = NULL;
      impl->h_array_borrowed = array;
      impl->d_array          = impl->h_array_borrowed;
      break;
  }
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
253+
205254
//------------------------------------------------------------------------------
206255
// Set the array used by a vector,
207256
// freeing any previously allocated array if applicable
@@ -213,7 +262,11 @@ static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
213262
CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
214263
switch (mem_type) {
215264
case CEED_MEM_HOST:
216-
return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
265+
if (impl->has_unified_addressing) {
266+
return CeedVectorSetArrayUnifiedHostToDevice_Hip(vec, copy_mode, array);
267+
} else {
268+
return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
269+
}
217270
case CEED_MEM_DEVICE:
218271
return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
219272
}
@@ -303,8 +356,10 @@ int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val)
303356
static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
304357
CeedSize length;
305358
CeedVector_Hip *impl;
359+
Ceed_Hip *hip_data;
306360

307361
CeedCallBackend(CeedVectorGetData(vec, &impl));
362+
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
308363
CeedCallBackend(CeedVectorGetLength(vec, &length));
309364
// Set value for synced device/host array
310365
if (!impl->d_array && !impl->h_array) {
@@ -321,7 +376,7 @@ static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
321376
}
322377
}
323378
if (impl->d_array) {
324-
if (val == 0) {
379+
if (val == 0 && !impl->h_array_borrowed) {
325380
CeedCallHip(CeedVectorReturnCeed(vec), hipMemset(impl->d_array, 0, length * sizeof(CeedScalar)));
326381
} else {
327382
CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
@@ -398,14 +453,17 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedSca
398453
}
399454

400455
//------------------------------------------------------------------------------
401-
// Core logic for array syncronization for GetArray.
456+
// Core logic for array synchronization for GetArray.
402457
// If a different memory type is most up to date, this will perform a copy
403458
//------------------------------------------------------------------------------
404-
static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
459+
static int CeedVectorGetArrayCore_Hip(const CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
405460
CeedVector_Hip *impl;
406461

407462
CeedCallBackend(CeedVectorGetData(vec, &impl));
408463

464+
// Use device memory for unified memory
465+
mem_type = impl->has_unified_addressing && !impl->h_array_borrowed ? CEED_MEM_DEVICE : mem_type;
466+
409467
// Sync array to requested mem_type
410468
CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
411469

@@ -431,10 +489,15 @@ static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType me
431489
//------------------------------------------------------------------------------
432490
// Get read/write access to a vector via the specified mem_type
433491
//------------------------------------------------------------------------------
434-
static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
492+
static int CeedVectorGetArray_Hip(const CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
435493
CeedVector_Hip *impl;
436494

437495
CeedCallBackend(CeedVectorGetData(vec, &impl));
496+
497+
// Use device memory for unified memory
498+
mem_type = impl->has_unified_addressing && !impl->h_array_borrowed ? CEED_MEM_DEVICE : mem_type;
499+
500+
// 'Get' array and set only 'get'ed array as valid
438501
CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
439502
CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
440503
switch (mem_type) {
@@ -451,11 +514,17 @@ static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
451514
//------------------------------------------------------------------------------
452515
// Get write access to a vector via the specified mem_type
453516
//------------------------------------------------------------------------------
454-
static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
517+
static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
455518
bool has_array_of_type = true;
456519
CeedVector_Hip *impl;
520+
Ceed_Hip *hip_data;
457521

458522
CeedCallBackend(CeedVectorGetData(vec, &impl));
523+
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
524+
525+
// Use device memory for unified memory
526+
mem_type = impl->has_unified_addressing && !impl->h_array_borrowed ? CEED_MEM_DEVICE : mem_type;
527+
459528
CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
460529
if (!has_array_of_type) {
461530
// Allocate if array is not yet allocated
@@ -487,8 +556,10 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
487556
const CeedScalar *d_array;
488557
CeedVector_Hip *impl;
489558
hipblasHandle_t handle;
559+
Ceed_Hip *hip_data;
490560

491561
CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
562+
CeedCallBackend(CeedGetData(ceed, &hip_data));
492563
CeedCallBackend(CeedVectorGetData(vec, &impl));
493564
CeedCallBackend(CeedVectorGetLength(vec, &length));
494565
CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
@@ -518,7 +589,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
518589
CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
519590
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
520591

521-
CeedCallHipblas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
592+
CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
522593
*norm += sub_norm;
523594
}
524595
#endif /* HIP_VERSION */
@@ -545,7 +616,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
545616
#if defined(CEED_SCALAR_IS_FP32)
546617
#if (HIP_VERSION >= 60000000)
547618
CeedCallHipblas(ceed, hipblasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
548-
#else /* CUDA_VERSION */
619+
#else /* HIP_VERSION */
549620
float sub_norm = 0.0, norm_sum = 0.0;
550621
float *d_array_start;
551622

@@ -562,7 +633,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
562633
#else /* CEED_SCALAR */
563634
#if (HIP_VERSION >= 60000000)
564635
CeedCallHipblas(ceed, hipblasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
565-
#else /* CUDA_VERSION */
636+
#else /* HIP_VERSION */
566637
double sub_norm = 0.0, norm_sum = 0.0;
567638
double *d_array_start;
568639

@@ -599,7 +670,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
599670
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
600671

601672
CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
602-
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
673+
if (hip_data->has_unified_addressing) {
674+
CeedCallHip(ceed, hipDeviceSynchronize());
675+
sub_max = fabs(d_array[index - 1]);
676+
} else {
677+
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
678+
}
603679
if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
604680
}
605681
*norm = current_max;
@@ -610,7 +686,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
610686
CeedScalar norm_no_abs;
611687

612688
CeedCallHipblas(ceed, hipblasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
613-
CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
689+
if (hip_data->has_unified_addressing) {
690+
CeedCallHip(ceed, hipDeviceSynchronize());
691+
norm_no_abs = fabs(d_array[index - 1]);
692+
} else {
693+
CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
694+
}
614695
*norm = fabs(norm_no_abs);
615696
#else /* HIP_VERSION */
616697
CeedInt index;
@@ -623,7 +704,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
623704
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
624705

625706
CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
626-
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
707+
if (hip_data->has_unified_addressing) {
708+
CeedCallHip(ceed, hipDeviceSynchronize());
709+
sub_max = fabs(d_array[index - 1]);
710+
} else {
711+
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
712+
}
627713
if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
628714
}
629715
*norm = current_max;
@@ -854,6 +940,7 @@ static int CeedVectorDestroy_Hip(const CeedVector vec) {
854940
//------------------------------------------------------------------------------
855941
int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
856942
CeedVector_Hip *impl;
943+
Ceed_Hip *hip_impl;
857944
Ceed ceed;
858945

859946
CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
@@ -875,8 +962,10 @@ int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
875962
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Hip));
876963
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
877964
CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
878-
CeedCallBackend(CeedDestroy(&ceed));
879965
CeedCallBackend(CeedCalloc(1, &impl));
966+
CeedCallBackend(CeedGetData(ceed, &hip_impl));
967+
CeedCallBackend(CeedDestroy(&ceed));
968+
impl->has_unified_addressing = hip_impl->has_unified_addressing;
880969
CeedCallBackend(CeedVectorSetData(vec, impl));
881970
return CEED_ERROR_SUCCESS;
882971
}

backends/hip-ref/ceed-hip-ref.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#endif
1818

1919
typedef struct {
20+
int *has_unified_addressing;
2021
CeedScalar *h_array;
2122
CeedScalar *h_array_borrowed;
2223
CeedScalar *h_array_owned;

backends/hip/ceed-hip-common.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ int CeedInit_Hip(Ceed ceed, const char *resource) {
1919
Ceed_Hip *data;
2020
const char *device_spec = strstr(resource, ":device_id=");
2121
const int device_id = (device_spec) ? atoi(device_spec + 11) : -1;
22-
int current_device_id;
22+
int current_device_id, xnack_value;
23+
const char *xnack;
2324

2425
CeedCallHip(ceed, hipGetDevice(&current_device_id));
2526
if (device_id >= 0 && current_device_id != device_id) {
@@ -30,6 +31,12 @@ int CeedInit_Hip(Ceed ceed, const char *resource) {
3031
CeedCallBackend(CeedGetData(ceed, &data));
3132
data->device_id = current_device_id;
3233
CeedCallHip(ceed, hipGetDeviceProperties(&data->device_prop, current_device_id));
34+
xnack = getenv("HSA_XNACK");
35+
xnack_value = !!xnack ? atol(xnack) : 0;
36+
data->has_unified_addressing = xnack_value > 0 ? data->device_prop.unifiedAddressing : 0;
37+
if (data->has_unified_addressing) {
38+
CeedDebug(ceed, "Using unified memory addressing");
39+
}
3340
data->opt_block_size = 256;
3441
return CEED_ERROR_SUCCESS;
3542
}

backends/hip/ceed-hip-common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ typedef struct {
7272
hipblasHandle_t hipblas_handle;
7373
struct hipDeviceProp_t device_prop;
7474
int opt_block_size;
75+
int has_unified_addressing;
7576
} Ceed_Hip;
7677

7778
CEED_INTERN int CeedInit_Hip(Ceed ceed, const char *resource);

0 commit comments

Comments
 (0)