Skip to content

Commit f1c05e9

Browse files
author
Sourabh Betigeri
committed
SWDEV-421020 - Adds hipGraphAddBatchMemOp, SetGetParams and execSetParams APIs
Change-Id: Ieccecfe6173cc68fd3c01f86c99f7cc09fe194a3
1 parent 93f1e8f commit f1c05e9

File tree

8 files changed

+294
-8
lines changed

8 files changed

+294
-8
lines changed

hipamd/include/hip/amd_detail/hip_api_trace.hpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
// - Reset any of the *_STEP_VERSION defines to zero if the corresponding *_MAJOR_VERSION increases
6262
#define HIP_API_TABLE_STEP_VERSION 0
6363
#define HIP_COMPILER_API_TABLE_STEP_VERSION 0
64-
#define HIP_RUNTIME_API_TABLE_STEP_VERSION 7
64+
#define HIP_RUNTIME_API_TABLE_STEP_VERSION 8
6565

6666
// HIP API interface
6767
typedef hipError_t (*t___hipPopCallConfiguration)(dim3* gridDim, dim3* blockDim, size_t* sharedMem,
@@ -723,7 +723,8 @@ typedef hipError_t (*t_hipStreamWriteValue32)(hipStream_t stream, void* ptr, uin
723723
typedef hipError_t (*t_hipStreamWriteValue64)(hipStream_t stream, void* ptr, uint64_t value,
724724
unsigned int flags);
725725
typedef hipError_t (*t_hipStreamBatchMemOp)(hipStream_t stream, unsigned int count,
726-
hipStreamBatchMemOpParams* paramArray, unsigned int flags);
726+
hipStreamBatchMemOpParams* paramArray,
727+
unsigned int flags);
727728
typedef hipError_t (*t_hipTexObjectCreate)(hipTextureObject_t* pTexObject,
728729
const HIP_RESOURCE_DESC* pResDesc,
729730
const HIP_TEXTURE_DESC* pTexDesc,
@@ -1006,6 +1007,17 @@ typedef hipError_t (*t_hipExtHostAlloc)(void **ptr, size_t size,
10061007
typedef hipError_t (*t_hipDeviceGetTexture1DLinearMaxWidth)(size_t *maxWidthInElements,
10071008
const hipChannelFormatDesc *fmtDesc,
10081009
int device);
1010+
1011+
typedef hipError_t (*t_hipGraphAddBatchMemOpNode)(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
1012+
const hipGraphNode_t* dependencies,
1013+
size_t numDependencies,
1014+
const hipBatchMemOpNodeParams* nodeParams);
1015+
typedef hipError_t (*t_hipGraphBatchMemOpNodeGetParams)(hipGraphNode_t hNode,
1016+
hipBatchMemOpNodeParams* nodeParams_out);
1017+
typedef hipError_t (*t_hipGraphBatchMemOpNodeSetParams)(hipGraphNode_t hNode,
1018+
hipBatchMemOpNodeParams* nodeParams);
1019+
typedef hipError_t (*t_hipGraphExecBatchMemOpNodeSetParams)(
1020+
hipGraphExec_t hGraphExec, hipGraphNode_t hNode, const hipBatchMemOpNodeParams* nodeParams);
10091021
// HIP Compiler dispatch table
10101022
struct HipCompilerDispatchTable {
10111023
// HIP_COMPILER_API_TABLE_STEP_VERSION == 0
@@ -1524,6 +1536,12 @@ struct HipDispatchTable {
15241536
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 7
15251537
t_hipStreamBatchMemOp hipStreamBatchMemOp_fn;
15261538

1539+
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 8
1540+
t_hipGraphAddBatchMemOpNode hipGraphAddBatchMemOpNode_fn;
1541+
t_hipGraphBatchMemOpNodeGetParams hipGraphBatchMemOpNodeGetParams_fn;
1542+
t_hipGraphBatchMemOpNodeSetParams hipGraphBatchMemOpNodeSetParams_fn;
1543+
t_hipGraphExecBatchMemOpNodeSetParams hipGraphExecBatchMemOpNodeSetParams_fn;
1544+
15271545
// DO NOT EDIT ABOVE!
15281546
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 7
15291547

hipamd/include/hip/amd_detail/hip_prof_str.h

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,11 @@ enum hip_api_id_t {
426426
HIP_API_ID_hipSetValidDevices = 406,
427427
HIP_API_ID_hipExtHostAlloc = 407,
428428
HIP_API_ID_hipStreamBatchMemOp = 408,
429-
HIP_API_ID_LAST = 408,
429+
HIP_API_ID_hipGraphAddBatchMemOpNode = 409,
430+
HIP_API_ID_hipGraphBatchMemOpNodeGetParams = 410,
431+
HIP_API_ID_hipGraphBatchMemOpNodeSetParams = 411,
432+
HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams = 412,
433+
HIP_API_ID_LAST = 412,
430434

431435
HIP_API_ID_hipChooseDevice = HIP_API_ID_CONCAT(HIP_API_ID_,hipChooseDevice),
432436
HIP_API_ID_hipGetDeviceProperties = HIP_API_ID_CONCAT(HIP_API_ID_,hipGetDeviceProperties),
@@ -861,6 +865,10 @@ static inline const char* hip_api_name(const uint32_t id) {
861865
case HIP_API_ID_hipWaitExternalSemaphoresAsync: return "hipWaitExternalSemaphoresAsync";
862866
case HIP_API_ID_hipExtGetLastError: return "hipExtGetLastError";
863867
case HIP_API_ID_hipStreamBatchMemOp: return "hipStreamBatchMemOp";
868+
case HIP_API_ID_hipGraphAddBatchMemOpNode: return "hipGraphAddBatchMemOpNode";
869+
case HIP_API_ID_hipGraphBatchMemOpNodeGetParams: return "hipGraphBatchMemOpNodeGetParams";
870+
case HIP_API_ID_hipGraphBatchMemOpNodeSetParams: return "hipGraphBatchMemOpNodeSetParams";
871+
case HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams: return "hipGraphExecBatchMemOpNodeSetParams";
864872
};
865873
return "unknown";
866874
};
@@ -1265,6 +1273,10 @@ static inline uint32_t hipApiIdByName(const char* name) {
12651273
if (strcmp("hipUserObjectRetain", name) == 0) return HIP_API_ID_hipUserObjectRetain;
12661274
if (strcmp("hipWaitExternalSemaphoresAsync", name) == 0) return HIP_API_ID_hipWaitExternalSemaphoresAsync;
12671275
if (strcmp("hipStreamBatchMemOp", name) == 0) return HIP_API_ID_hipStreamBatchMemOp;
1276+
if (strcmp("hipGraphAddBatchMemOpNode", name) == 0) return HIP_API_ID_hipGraphAddBatchMemOpNode;
1277+
if (strcmp("hipGraphBatchMemOpNodeGetParams", name) == 0) return HIP_API_ID_hipGraphBatchMemOpNodeGetParams;
1278+
if (strcmp("hipGraphBatchMemOpNodeSetParams", name) == 0) return HIP_API_ID_hipGraphBatchMemOpNodeSetParams;
1279+
if (strcmp("hipGraphExecBatchMemOpNodeSetParams", name) == 0) return HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams;
12681280
return HIP_API_ID_NONE;
12691281
}
12701282

@@ -3633,6 +3645,32 @@ typedef struct hip_api_data_s {
36333645
hipStreamBatchMemOpParams paramArray__val;
36343646
unsigned int flags;
36353647
} hipStreamBatchMemOp;
3648+
struct {
3649+
hipGraphNode_t* phGraphNode;
3650+
hipGraphNode_t phGraphNode__val;
3651+
hipGraph_t hGraph;
3652+
const hipGraphNode_t* dependencies;
3653+
hipGraphNode_t dependencies__val;
3654+
size_t numDependencies;
3655+
const hipBatchMemOpNodeParams* nodeParams;
3656+
hipBatchMemOpNodeParams nodeParams__val;
3657+
} hipGraphAddBatchMemOpNode;
3658+
struct {
3659+
hipGraphNode_t hNode;
3660+
hipBatchMemOpNodeParams* nodeParams_out;
3661+
hipBatchMemOpNodeParams nodeParams_out__val;
3662+
} hipGraphBatchMemOpNodeGetParams;
3663+
struct {
3664+
hipGraphNode_t hNode;
3665+
hipBatchMemOpNodeParams* nodeParams;
3666+
hipBatchMemOpNodeParams nodeParams__val;
3667+
} hipGraphBatchMemOpNodeSetParams;
3668+
struct {
3669+
hipGraphExec_t hGraphExec;
3670+
hipGraphNode_t hNode;
3671+
const hipBatchMemOpNodeParams* nodeParams;
3672+
hipBatchMemOpNodeParams nodeParams__val;
3673+
} hipGraphExecBatchMemOpNodeSetParams;
36363674
} args;
36373675
uint64_t *phase_data;
36383676
} hip_api_data_t;
@@ -6045,6 +6083,36 @@ typedef struct hip_api_data_s {
60456083
cb_data.args.hipWaitExternalSemaphoresAsync.numExtSems = (unsigned int)numExtSems; \
60466084
cb_data.args.hipWaitExternalSemaphoresAsync.stream = (hipStream_t)stream; \
60476085
};
6086+
6087+
// hipGraphAddBatchMemOpNode[('hipGraphNode_t*', 'phGraphNode'), ('hipGraph_t', 'hGraph'),
6088+
// ('hipGraphNode_t*', 'dependencies'), ('size_t', 'numDependencies'),
6089+
// ('hipBatchMemOpNodeParams*'), 'nodeParams')]
6090+
#define INIT_hipGraphAddBatchMemOpNode_CB_ARGS_DATA(cb_data) { \
6091+
cb_data.args.hipGraphAddBatchMemOpNode.phGraphNode = (hipGraphNode_t*)phGraphNode; \
6092+
cb_data.args.hipGraphAddBatchMemOpNode.hGraph = (hipGraph_t)hGraph; \
6093+
cb_data.args.hipGraphAddBatchMemOpNode.dependencies= (hipGraphNode_t*)dependencies; \
6094+
cb_data.args.hipGraphAddBatchMemOpNode.numDependencies = (size_t)numDependencies; \
6095+
cb_data.args.hipGraphAddBatchMemOpNode.nodeParams = (hipBatchMemOpNodeParams*)nodeParams; \
6096+
};
6097+
// hipGraphBatchMemOpNodeGetParams[('hipGraphNode_t', hNode),
6098+
// ('hipBatchMemOpNodeParams*', 'nodeParams_out')]
6099+
#define INIT_hipGraphBatchMemOpNodeGetParams_CB_ARGS_DATA(cb_data) { \
6100+
cb_data.args.hipGraphBatchMemOpNodeGetParams.hNode = (hipGraphNode_t)hNode; \
6101+
cb_data.args.hipGraphBatchMemOpNodeGetParams.nodeParams_out = (hipBatchMemOpNodeParams*)nodeParams_out; \
6102+
};
6103+
// hipGraphBatchMemOpNodeSetParams[('hipGraphNode_t', hNode),
6104+
// ('hipBatchMemOpNodeParams*', 'nodeParams')]
6105+
#define INIT_hipGraphBatchMemOpNodeSetParams_CB_ARGS_DATA(cb_data) { \
6106+
cb_data.args.hipGraphBatchMemOpNodeSetParams.hNode = (hipGraphNode_t)hNode; \
6107+
cb_data.args.hipGraphBatchMemOpNodeSetParams.nodeParams = (hipBatchMemOpNodeParams*)nodeParams; \
6108+
};
6109+
// hipGraphExecBatchMemOpNodeSetParams[('hipGraphExec_t'. hGraphExec),
6110+
// ('hipGraphNode_t'. hNode), ('hipBatchMemOpNodeParams*', 'nodeParams')]
6111+
#define INIT_hipGraphExecBatchMemOpNodeSetParams_CB_ARGS_DATA(cb_data) { \
6112+
cb_data.args.hipGraphExecBatchMemOpNodeSetParams.hGraphExec = (hipGraphExec_t)hGraphExec; \
6113+
cb_data.args.hipGraphExecBatchMemOpNodeSetParams.hNode = (hipGraphNode_t)hNode; \
6114+
cb_data.args.hipGraphExecBatchMemOpNodeSetParams.nodeParams= (hipBatchMemOpNodeParams*)nodeParams; \
6115+
};
60486116
#define INIT_CB_ARGS_DATA(cb_id, cb_data) INIT_##cb_id##_CB_ARGS_DATA(cb_data)
60496117

60506118
// Macros for non-public API primitives
@@ -7551,6 +7619,29 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
75517619
case HIP_API_ID_hipStreamBatchMemOp:
75527620
if (data->args.hipStreamBatchMemOp.paramArray) data->args.hipStreamBatchMemOp.paramArray__val = *(data->args.hipStreamBatchMemOp.paramArray);
75537621
break;
7622+
// hipGraphAddBatchMemOpNode[('hipGraphNode_t*', 'phGraphNode'), ('hipGraph_t', 'hGraph'),
7623+
// ('hipGraphNode_t*', 'dependencies'), ('size_t', 'numDependencies'),
7624+
// ('hipBatchMemOpNodeParams*'), 'nodeParams')]
7625+
case HIP_API_ID_hipGraphAddBatchMemOpNode:
7626+
if (data->args.hipGraphAddBatchMemOpNode.phGraphNode) data->args.hipGraphAddBatchMemOpNode.phGraphNode__val = *(data->args.hipGraphAddBatchMemOpNode.phGraphNode);
7627+
if (data->args.hipGraphAddBatchMemOpNode.dependencies) data->args.hipGraphAddBatchMemOpNode.dependencies__val = *(data->args.hipGraphAddBatchMemOpNode.dependencies);
7628+
if (data->args.hipGraphAddBatchMemOpNode.nodeParams) data->args.hipGraphAddBatchMemOpNode.nodeParams__val = *(data->args.hipGraphAddBatchMemOpNode.nodeParams);
7629+
break;
7630+
// hipGraphBatchMemOpNodeGetParams[('hipGraphNode_t', hNode),
7631+
// ('hipBatchMemOpNodeParams*', 'nodeParams_out')]
7632+
case HIP_API_ID_hipGraphBatchMemOpNodeGetParams:
7633+
if (data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out) data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out__val = *(data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out);
7634+
break;
7635+
// hipGraphBatchMemOpNodeSetParams[('hipGraphNode_t', hNode),
7636+
// ('hipBatchMemOpNodeParams*', 'nodeParams')]
7637+
case HIP_API_ID_hipGraphBatchMemOpNodeSetParams:
7638+
if (data->args.hipGraphBatchMemOpNodeSetParams.nodeParams) data->args.hipGraphBatchMemOpNodeSetParams.nodeParams__val = *(data->args.hipGraphBatchMemOpNodeSetParams.nodeParams);
7639+
break;
7640+
// hipGraphExecBatchMemOpNodeSetParams[('hipGraphExec_t'. hGraphExec),
7641+
// ('hipGraphNode_t'. hNode), ('hipBatchMemOpNodeParams*', 'nodeParams')]
7642+
case HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams:
7643+
if (data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams) data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams__val = *(data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams);
7644+
break;
75547645
// hipTexRefGetAddress[('hipDeviceptr_t*', 'dev_ptr'), ('const textureReference*', 'texRef')]
75557646
case HIP_API_ID_hipTexRefGetAddress:
75567647
if (data->args.hipTexRefGetAddress.dev_ptr) data->args.hipTexRefGetAddress.dev_ptr__val = *(data->args.hipTexRefGetAddress.dev_ptr);
@@ -10843,6 +10934,34 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
1084310934
oss << ", stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipWaitExternalSemaphoresAsync.stream);
1084410935
oss << ")";
1084510936
break;
10937+
case HIP_API_ID_hipGraphAddBatchMemOpNode:
10938+
oss << "hipGraphAddBatchMemOpNode(";
10939+
oss << "phGraphNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.phGraphNode);
10940+
oss << ", hGraph="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.hGraph);
10941+
oss << ", dependencies="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.dependencies);
10942+
oss << ", numDependencies="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.numDependencies);
10943+
oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.nodeParams);
10944+
oss << ")";
10945+
break;
10946+
case HIP_API_ID_hipGraphBatchMemOpNodeGetParams:
10947+
oss << "hipGraphBatchMemOpNodeGetParams(";
10948+
oss << "hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeGetParams.hNode);
10949+
oss << ", nodeParams_out="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out);
10950+
oss << ")";
10951+
break;
10952+
case HIP_API_ID_hipGraphBatchMemOpNodeSetParams:
10953+
oss << "hipGraphBatchMemOpNodeSetParams(";
10954+
oss << "hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeSetParams.hNode);
10955+
oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeSetParams.nodeParams);
10956+
oss << ")";
10957+
break;
10958+
case HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams:
10959+
oss << "hipGraphExecBatchMemOpNodeSetParams(";
10960+
oss << "hGraphExec="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphExecBatchMemOpNodeSetParams.hGraphExec);
10961+
oss << ", hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphExecBatchMemOpNodeSetParams.hNode);
10962+
oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams);
10963+
oss << ")";
10964+
break;
1084610965
default: oss << "unknown";
1084710966
};
1084810967
return strdup(oss.str().c_str());

hipamd/src/amdhip.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,3 +481,7 @@ hipDrvGraphMemcpyNodeSetParams
481481
hipDrvGraphMemcpyNodeGetParams
482482
hipExtHostAlloc
483483
hipStreamBatchMemOp
484+
hipGraphAddBatchMemOpNode
485+
hipGraphBatchMemOpNodeGetParams
486+
hipGraphBatchMemOpNodeSetParams
487+
hipGraphExecBatchMemOpNodeSetParams

hipamd/src/hip_api_trace.cpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,15 @@ hipError_t hipExternalMemoryGetMappedMipmappedArray(
804804
const hipExternalMemoryMipmappedArrayDesc* mipmapDesc);
805805
hipError_t hipDrvGraphMemcpyNodeGetParams(hipGraphNode_t hNode, HIP_MEMCPY3D* nodeParams);
806806
hipError_t hipDrvGraphMemcpyNodeSetParams(hipGraphNode_t hNode, const HIP_MEMCPY3D* nodeParams);
807-
807+
hipError_t hipGraphAddBatchMemOpNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
808+
const hipGraphNode_t* dependencies, size_t numDependencies,
809+
const hipBatchMemOpNodeParams* nodeParams);
810+
hipError_t hipGraphBatchMemOpNodeGetParams(hipGraphNode_t hNode,
811+
hipBatchMemOpNodeParams* nodeParams_out);
812+
hipError_t hipGraphBatchMemOpNodeSetParams(hipGraphNode_t hNode,
813+
hipBatchMemOpNodeParams* nodeParams);
814+
hipError_t hipGraphExecBatchMemOpNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
815+
const hipBatchMemOpNodeParams* nodeParams);
808816
} // namespace hip
809817

810818
namespace hip {
@@ -1301,6 +1309,11 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
13011309
hip::hipExternalMemoryGetMappedMipmappedArray;
13021310
ptrDispatchTable->hipDrvGraphMemcpyNodeGetParams_fn = hip::hipDrvGraphMemcpyNodeGetParams;
13031311
ptrDispatchTable->hipDrvGraphMemcpyNodeSetParams_fn = hip::hipDrvGraphMemcpyNodeSetParams;
1312+
ptrDispatchTable->hipGraphAddBatchMemOpNode_fn = hip::hipGraphAddBatchMemOpNode;
1313+
ptrDispatchTable->hipGraphBatchMemOpNodeGetParams_fn = hip::hipGraphBatchMemOpNodeGetParams;
1314+
ptrDispatchTable->hipGraphBatchMemOpNodeSetParams_fn = hip::hipGraphBatchMemOpNodeSetParams;
1315+
ptrDispatchTable->hipGraphExecBatchMemOpNodeSetParams_fn =
1316+
hip::hipGraphExecBatchMemOpNodeSetParams;
13041317
}
13051318

13061319
#if HIP_ROCPROFILER_REGISTER > 0
@@ -1892,17 +1905,21 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipExtHostAlloc_fn, 461)
18921905
HIP_ENFORCE_ABI(HipDispatchTable, hipDeviceGetTexture1DLinearMaxWidth_fn, 462)
18931906
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 7
18941907
HIP_ENFORCE_ABI(HipDispatchTable, hipStreamBatchMemOp_fn, 463);
1895-
1908+
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 8
1909+
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphAddBatchMemOpNode_fn, 464);
1910+
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphBatchMemOpNodeGetParams_fn, 465);
1911+
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphBatchMemOpNodeSetParams_fn, 466);
1912+
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecBatchMemOpNodeSetParams_fn, 467);
18961913

18971914
// if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
18981915
// will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
18991916
//
19001917
// HIP_ENFORCE_ABI(<table>, <functor>, 8)
19011918
//
19021919
// HIP_ENFORCE_ABI_VERSIONING(<table>, 9) <- 8 + 1 = 9
1903-
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 464)
1920+
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 468)
19041921

1905-
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 7,
1922+
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 8,
19061923
"If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
19071924
"pointers and then update this check so it is true");
19081925
#endif

hipamd/src/hip_graph.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3427,4 +3427,58 @@ hipError_t hipDrvGraphMemcpyNodeSetParams(hipGraphNode_t hNode, const HIP_MEMCPY
34273427
HIP_RETURN(reinterpret_cast<hip::GraphDrvMemcpyNode*>(hNode)->SetParams(nodeParams));
34283428
}
34293429

3430+
hipError_t hipGraphAddBatchMemOpNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
3431+
const hipGraphNode_t* dependencies, size_t numDependencies,
3432+
const hipBatchMemOpNodeParams* nodeParams) {
3433+
HIP_INIT_API(hipGraphAddBatchMemOpNode, phGraphNode, hGraph, dependencies, numDependencies,
3434+
nodeParams);
3435+
if (phGraphNode == nullptr || hGraph == nullptr ||
3436+
(numDependencies > 0 && dependencies == nullptr) || nodeParams == nullptr) {
3437+
HIP_RETURN(hipErrorInvalidValue);
3438+
}
3439+
hip::GraphNode* node = new hip::hipGraphBatchMemOpNode(nodeParams);
3440+
hipError_t status =
3441+
ihipGraphAddNode(node, reinterpret_cast<hip::Graph*>(hGraph),
3442+
reinterpret_cast<hip::GraphNode* const*>(dependencies), numDependencies);
3443+
*phGraphNode = reinterpret_cast<hipGraphNode_t>(node);
3444+
HIP_RETURN(status);
3445+
}
3446+
3447+
hipError_t hipGraphBatchMemOpNodeGetParams(hipGraphNode_t hNode,
3448+
hipBatchMemOpNodeParams* nodeParams_out) {
3449+
HIP_INIT_API(hipGraphBatchMemOpNodeGetParams, hNode, nodeParams_out);
3450+
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
3451+
if (!hip::GraphNode::isNodeValid(n) || nodeParams_out == nullptr) {
3452+
HIP_RETURN(hipErrorInvalidValue);
3453+
}
3454+
reinterpret_cast<hip::hipGraphBatchMemOpNode*>(n)->GetParams(nodeParams_out);
3455+
HIP_RETURN(hipSuccess);
3456+
}
3457+
3458+
hipError_t hipGraphBatchMemOpNodeSetParams(hipGraphNode_t hNode,
3459+
hipBatchMemOpNodeParams* nodeParams) {
3460+
HIP_INIT_API(hipGraphBatchMemOpNodeSetParams, hNode, nodeParams);
3461+
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
3462+
if (!hip::GraphNode::isNodeValid(n) || nodeParams == nullptr) {
3463+
HIP_RETURN(hipErrorInvalidValue);
3464+
}
3465+
HIP_RETURN(reinterpret_cast<hip::hipGraphBatchMemOpNode*>(n)->SetParams(nodeParams));
3466+
}
3467+
3468+
hipError_t hipGraphExecBatchMemOpNodeSetParams(hipGraphExec_t hGraphExec,
3469+
hipGraphNode_t hNode,
3470+
const hipBatchMemOpNodeParams* nodeParams) {
3471+
HIP_INIT_API(hipGraphExecBatchMemOpNodeSetParams, hGraphExec, hNode, nodeParams);
3472+
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
3473+
hip::GraphExec* graphExec = reinterpret_cast<hip::GraphExec*>(hGraphExec);
3474+
if (hGraphExec == nullptr || hNode == nullptr || !hip::GraphExec::isGraphExecValid(graphExec) ||
3475+
!hip::GraphNode::isNodeValid(n) || nodeParams == nullptr) {
3476+
HIP_RETURN(hipErrorInvalidValue);
3477+
}
3478+
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphExec*>(graphExec)->GetClonedNode(n);
3479+
if (clonedNode == nullptr) {
3480+
HIP_RETURN(hipErrorInvalidValue);
3481+
}
3482+
HIP_RETURN(reinterpret_cast<hip::hipGraphBatchMemOpNode*>(clonedNode)->SetParams(nodeParams));
3483+
}
34303484
} // namespace hip

0 commit comments

Comments
 (0)