Skip to content

Commit a6008e9

Browse files
committed
hip - use unified memory
co-authored-by: Zach Atkins <zach.atkins@colorado.edu>
1 parent a637ca9 commit a6008e9

File tree

3 files changed

+73
-8
lines changed

3 files changed

+73
-8
lines changed

backends/hip-ref/ceed-hip-ref-vector.c

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,21 @@ static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
9595
// Sync arrays
9696
//------------------------------------------------------------------------------
9797
static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
98-
bool need_sync = false;
98+
bool need_sync = false;
99+
Ceed_Hip *hip_data;
100+
101+
// Sync for unified memory
102+
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
103+
if (hip_data->has_unified_addressing) {
104+
Ceed ceed;
105+
CeedVector_Hip *impl;
106+
107+
CeedCallBackend(CeedVectorGetData(vec, &impl));
108+
CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
109+
CeedCallHip(ceed, hipDeviceSynchronize());
110+
CeedCallBackend(CeedDestroy(&ceed));
111+
return CEED_ERROR_SUCCESS;
112+
}
99113

100114
// Check whether device/host sync is needed
101115
CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
@@ -156,8 +170,14 @@ static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType
156170
//------------------------------------------------------------------------------
157171
static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
158172
CeedVector_Hip *impl;
173+
Ceed_Hip *hip_data;
159174

160175
CeedCallBackend(CeedVectorGetData(vec, &impl));
176+
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
177+
178+
// Use device memory for unified memory
179+
mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;
180+
161181
switch (mem_type) {
162182
case CEED_MEM_HOST:
163183
*has_borrowed_array_of_type = impl->h_array_borrowed;
@@ -303,8 +323,10 @@ int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val)
303323
static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
304324
CeedSize length;
305325
CeedVector_Hip *impl;
326+
Ceed_Hip *hip_data;
306327

307328
CeedCallBackend(CeedVectorGetData(vec, &impl));
329+
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
308330
CeedCallBackend(CeedVectorGetLength(vec, &length));
309331
// Set value for synced device/host array
310332
if (!impl->d_array && !impl->h_array) {
@@ -403,8 +425,13 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedSca
403425
//------------------------------------------------------------------------------
404426
static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
405427
CeedVector_Hip *impl;
428+
Ceed_Hip *hip_data;
406429

407430
CeedCallBackend(CeedVectorGetData(vec, &impl));
431+
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
432+
433+
// Use device memory for unified memory
434+
mem_type = hip_data->has_unified_addressing && impl->d_array ? CEED_MEM_DEVICE : mem_type;
408435

409436
// Sync array to requested mem_type
410437
CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
@@ -433,8 +460,15 @@ static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType me
433460
//------------------------------------------------------------------------------
434461
static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
435462
CeedVector_Hip *impl;
463+
Ceed_Hip *hip_data;
436464

437465
CeedCallBackend(CeedVectorGetData(vec, &impl));
466+
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
467+
468+
// Use device memory for unified memory
469+
mem_type = hip_data->has_unified_addressing && impl->d_array ? CEED_MEM_DEVICE : mem_type;
470+
471+
// 'Get' array and set only 'get'ed array as valid
438472
CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
439473
CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
440474
switch (mem_type) {
@@ -454,8 +488,14 @@ static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
454488
static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
455489
bool has_array_of_type = true;
456490
CeedVector_Hip *impl;
491+
Ceed_Hip *hip_data;
457492

458493
CeedCallBackend(CeedVectorGetData(vec, &impl));
494+
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
495+
496+
// Use device memory for unified memory
497+
mem_type = hip_data->has_unified_addressing && impl->d_array ? CEED_MEM_DEVICE : mem_type;
498+
459499
CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
460500
if (!has_array_of_type) {
461501
// Allocate if array is not yet allocated
@@ -487,8 +527,10 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
487527
const CeedScalar *d_array;
488528
CeedVector_Hip *impl;
489529
hipblasHandle_t handle;
530+
Ceed_Hip *hip_data;
490531

491532
CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
533+
CeedCallBackend(CeedGetData(ceed, &hip_data));
492534
CeedCallBackend(CeedVectorGetData(vec, &impl));
493535
CeedCallBackend(CeedVectorGetLength(vec, &length));
494536
CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
@@ -518,7 +560,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
518560
CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
519561
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
520562

521-
CeedCallHipblas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
563+
CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
522564
*norm += sub_norm;
523565
}
524566
#endif /* HIP_VERSION */
@@ -545,7 +587,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
545587
#if defined(CEED_SCALAR_IS_FP32)
546588
#if (HIP_VERSION >= 60000000)
547589
CeedCallHipblas(ceed, hipblasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
548-
#else /* CUDA_VERSION */
590+
#else /* HIP_VERSION */
549591
float sub_norm = 0.0, norm_sum = 0.0;
550592
float *d_array_start;
551593

@@ -562,7 +604,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
562604
#else /* CEED_SCALAR */
563605
#if (HIP_VERSION >= 60000000)
564606
CeedCallHipblas(ceed, hipblasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
565-
#else /* CUDA_VERSION */
607+
#else /* HIP_VERSION */
566608
double sub_norm = 0.0, norm_sum = 0.0;
567609
double *d_array_start;
568610

@@ -599,7 +641,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
599641
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
600642

601643
CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
602-
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
644+
if ((hip_data->has_unified_addressing) {
645+
CeedCallHip(ceed, hipDeviceSynchronize());
646+
sub_max = fabs(d_array[index - 1]);
647+
} else {
648+
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
649+
}
603650
if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
604651
}
605652
*norm = current_max;
@@ -610,7 +657,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
610657
CeedScalar norm_no_abs;
611658

612659
CeedCallHipblas(ceed, hipblasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
613-
CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
660+
if ((hip_data->has_unified_addressing) {
661+
CeedCallHip(ceed, hipDeviceSynchronize());
662+
norm_no_abs = fabs(d_array[index - 1]);
663+
} else {
664+
CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
665+
}
614666
*norm = fabs(norm_no_abs);
615667
#else /* HIP_VERSION */
616668
CeedInt index;
@@ -623,7 +675,12 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
623675
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
624676

625677
CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
626-
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
678+
if ((hip_data->has_unified_addressing) {
679+
CeedCallHip(ceed, hipDeviceSynchronize());
680+
sub_max = fabs(d_array[index - 1]);
681+
} else {
682+
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
683+
}
627684
if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
628685
}
629686
*norm = current_max;

backends/hip/ceed-hip-common.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ int CeedInit_Hip(Ceed ceed, const char *resource) {
1919
Ceed_Hip *data;
2020
const char *device_spec = strstr(resource, ":device_id=");
2121
const int device_id = (device_spec) ? atoi(device_spec + 11) : -1;
22-
int current_device_id;
22+
int current_device_id, xnack_value;
23+
const char *xnack;
2324

2425
CeedCallHip(ceed, hipGetDevice(&current_device_id));
2526
if (device_id >= 0 && current_device_id != device_id) {
@@ -30,6 +31,12 @@ int CeedInit_Hip(Ceed ceed, const char *resource) {
3031
CeedCallBackend(CeedGetData(ceed, &data));
3132
data->device_id = current_device_id;
3233
CeedCallHip(ceed, hipGetDeviceProperties(&data->device_prop, current_device_id));
34+
xnack = getenv("HSA_XNACK");
35+
xnack_value = !!xnack ? atol(xnack) : 0;
36+
data->has_unified_addressing = xnack_value > 0 ? data->device_prop.unifiedAddressing : 0;
37+
if (data->has_unified_addressing) {
38+
CeedDebug(ceed, "Using unified memory addressing");
39+
}
3340
data->opt_block_size = 256;
3441
return CEED_ERROR_SUCCESS;
3542
}

backends/hip/ceed-hip-common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ typedef struct {
7272
hipblasHandle_t hipblas_handle;
7373
struct hipDeviceProp_t device_prop;
7474
int opt_block_size;
75+
int has_unified_addressing;
7576
} Ceed_Hip;
7677

7778
CEED_INTERN int CeedInit_Hip(Ceed ceed, const char *resource);

0 commit comments

Comments
 (0)