3939#include " paddle/fluid/platform/profiler/cuda_tracer.h"
4040#include " paddle/fluid/platform/profiler/cupti_data_process.h"
4141#include " paddle/phi/api/profiler/trace_event_collector.h"
42+ #include " paddle/phi/backends/custom/cuda_graph.h"
4243#include " paddle/phi/backends/device_base.h"
4344#include " paddle/phi/backends/device_ext.h"
4445#include " paddle/phi/backends/dynload/cublasLt.h"
@@ -671,6 +672,11 @@ C_Status Allocate(const C_Device device, void **ptr, size_t size) {
671672 return C_ERROR;
672673 }
673674
675+ // cudagraph temp code
676+ // cudaStream_t stream;
677+ // cudaStreamCreate(&stream);
678+ // err = cudaMallocAsync(ptr, size,cudaStreamDefault);
679+
674680 err = cudaMalloc (ptr, size);
675681 if (err != cudaSuccess) {
676682 *ptr = NULL ;
@@ -1268,6 +1274,162 @@ C_Status IsDNNSupported(const C_Device device, bool *supported) {
12681274 return C_SUCCESS;
12691275}
12701276
1277+ C_Status CudaStreamBeginCapture (const C_Device device,
1278+ C_Stream stream,
1279+ C_StreamCaptureMode mode) {
1280+ if (cudaStreamBeginCapture (cudaStream_t (stream),
1281+ cudaStreamCaptureMode (mode)) != cudaSuccess)
1282+ return C_ERROR;
1283+ return C_SUCCESS;
1284+ }
1285+
1286+ C_Status CudaStreamEndCaptrue (const C_Device device,
1287+ C_Stream stream,
1288+ C_CudaGraph *pGraph) {
1289+ if (cudaStreamEndCapture (cudaStream_t (stream),
1290+ reinterpret_cast <cudaGraph_t *>(pGraph)) !=
1291+ cudaSuccess)
1292+ return C_ERROR;
1293+ return C_SUCCESS;
1294+ }
1295+
1296+ C_Status CudaGraphGetNodes (C_CudaGraph graph,
1297+ C_CudaGraphNode *pNode,
1298+ size_t *numNodes) {
1299+ if (cudaGraphGetNodes (cudaGraph_t (graph), nullptr , numNodes) != cudaSuccess)
1300+ return C_ERROR;
1301+ return C_SUCCESS;
1302+ }
1303+
1304+ C_Status CudaGraphLaunch (const C_Device device,
1305+ C_GraphExec exec,
1306+ C_Stream stream) {
1307+ cudaError_t result =
1308+ cudaGraphLaunch (cudaGraphExec_t (exec), cudaStream_t (stream));
1309+ if (result != cudaSuccess) {
1310+ return C_ERROR;
1311+ }
1312+ return C_SUCCESS;
1313+ }
1314+
1315+ C_Status CudaGraphDestroy (C_CudaGraph graph) {
1316+ if (cudaGraphDestroy (cudaGraph_t (graph)) != cudaSuccess) return C_ERROR;
1317+ return C_SUCCESS;
1318+ }
1319+
1320+ C_Status CudaGraphExecDestroy (C_GraphExec exec) {
1321+ if (cudaGraphExecDestroy (cudaGraphExec_t (exec)) != cudaSuccess)
1322+ return C_ERROR;
1323+ return C_SUCCESS;
1324+ }
1325+
1326+ C_Status CudaGraphInstantiate (C_GraphExec *pExec,
1327+ C_CudaGraph *pGraph,
1328+ void **pErrorNode,
1329+ char *pLogBuffer,
1330+ size_t bufferSize) {
1331+ if (cudaGraphInstantiateWithFlags (reinterpret_cast <cudaGraphExec_t *>(pExec),
1332+ *(reinterpret_cast <cudaGraph_t *>(pGraph)),
1333+ cudaGraphInstantiateFlagAutoFreeOnLaunch) !=
1334+ cudaSuccess)
1335+ return C_ERROR;
1336+ return C_SUCCESS;
1337+ }
1338+
1339+ C_Status CudaStreamCaptureInfo (const C_Device device,
1340+ C_Stream stream,
1341+ C_StreamCaptureStatus *captureStatus_out,
1342+ unsigned long long *id_out, // NOLINT
1343+ C_CudaGraph *graph_out,
1344+ C_CudaGraphNode *dependencies_out,
1345+ void **edgeData_out,
1346+ size_t *numDependencies_out) {
1347+ if (cudaStreamGetCaptureInfo (
1348+ cudaStream_t (stream),
1349+ reinterpret_cast <cudaStreamCaptureStatus *>(captureStatus_out),
1350+ id_out) != cudaSuccess)
1351+ return C_ERROR;
1352+ return C_SUCCESS;
1353+ }
1354+
1355+ C_Status GetParameterSettersForExecGraph (C_CudaGraph graph,
1356+ C_GraphHookManager *c_hook) {
1357+ using parameterSetter_t =
1358+ std::function<void (phi::backends::gpu::gpuKernelParams &)>;
1359+ struct SetKernelParamsCtx {
1360+ parameterSetter_t setter_func;
1361+ cudaGraphNode_t node;
1362+ cudaKernelNodeParams params;
1363+ };
1364+
1365+ size_t num_nodes;
1366+ PADDLE_ENFORCE_GPU_SUCCESS (
1367+ cudaGraphGetNodes (cudaGraph_t (graph), nullptr , &num_nodes));
1368+ std::vector<cudaGraphNode_t> nodes (num_nodes);
1369+ PADDLE_ENFORCE_GPU_SUCCESS (
1370+ cudaGraphGetNodes (cudaGraph_t (graph), nodes.data (), &num_nodes));
1371+
1372+ std::vector<SetKernelParamsCtx *> cts_all;
1373+ std::vector<C_GraphExecuterSetter> hooks;
1374+ for (auto node : nodes) {
1375+ cudaGraphNodeType pType;
1376+ cudaError_t result = cudaGraphNodeGetType (node, &pType);
1377+ assert (result == CUDA_SUCCESS);
1378+ if (pType == cudaGraphNodeTypeKernel) {
1379+ cudaKernelNodeParams params = {};
1380+ result = cudaGraphKernelNodeGetParams (node, ¶ms);
1381+ assert (result == CUDA_SUCCESS);
1382+ phi::backends::gpu::gpuKernelParams kernel_params (params.kernelParams );
1383+ auto kernel =
1384+ phi::backends::gpu::CUDAGraphNodeLauncher::Instance ()
1385+ .parameterSetters .find (static_cast <cudaFunction_t>(params.func ));
1386+ if (kernel != phi::backends::gpu::CUDAGraphNodeLauncher::Instance ()
1387+ .parameterSetters .end ()) {
1388+ auto launchSequence = kernel->second ;
1389+ unsigned int id = kernel_params.As <int >(0 );
1390+ auto parameterSetter = launchSequence.find (id);
1391+ if (parameterSetter != launchSequence.end ()) {
1392+ auto setter = parameterSetter->second ;
1393+ SetKernelParamsCtx *ctx = new SetKernelParamsCtx;
1394+ ctx->node = node;
1395+ ctx->params = params;
1396+ ctx->setter_func = setter;
1397+ cts_all.emplace_back (ctx);
1398+ hooks.emplace_back ([](C_GraphExec exec_graph, void *userdate) {
1399+ SetKernelParamsCtx *tmp =
1400+ reinterpret_cast <SetKernelParamsCtx *>(userdate);
1401+ phi::backends::gpu::gpuKernelParams kernel_params (
1402+ tmp->params .kernelParams );
1403+ tmp->setter_func (kernel_params);
1404+ cudaGraphExecKernelNodeSetParams (
1405+ reinterpret_cast <cudaGraphExec_t>(exec_graph),
1406+ tmp->node ,
1407+ &tmp->params );
1408+ });
1409+ } else {
1410+ PADDLE_THROW (common::errors::InvalidArgument (
1411+ " Error: does not find launch id" ));
1412+ }
1413+ }
1414+ }
1415+ }
1416+
1417+ c_hook->size = cts_all.size ();
1418+ if (cts_all.size () != 0 ) {
1419+ c_hook->size = cts_all.size ();
1420+ c_hook->hooks = reinterpret_cast <C_GraphExecuterSetter *>(
1421+ malloc (sizeof (C_GraphExecuterSetter) * cts_all.size ()));
1422+ c_hook->user_data =
1423+ reinterpret_cast <void **>(malloc (sizeof (void *) * cts_all.size ()));
1424+ std::memcpy (c_hook->hooks ,
1425+ hooks.data (),
1426+ cts_all.size () * sizeof (C_GraphExecuterSetter));
1427+ std::memcpy (
1428+ c_hook->user_data , cts_all.data (), cts_all.size () * sizeof (void *));
1429+ }
1430+ return C_SUCCESS;
1431+ }
1432+
12711433void InitPlugin (CustomRuntimeParams *params) {
12721434 PADDLE_CUSTOM_RUNTIME_CHECK_VERSION (params);
12731435 params->device_type = const_cast <char *>(DeviceType);
@@ -1277,6 +1439,17 @@ void InitPlugin(CustomRuntimeParams *params) {
12771439 0 ,
12781440 sizeof (C_DeviceInterface));
12791441
1442+ params->interface ->cuda_stream_begin_capture = CudaStreamBeginCapture;
1443+ params->interface ->cuda_stream_end_captrue = CudaStreamEndCaptrue;
1444+ params->interface ->cuda_graph_launch = CudaGraphLaunch;
1445+ params->interface ->cuda_graph_destroy = CudaGraphDestroy;
1446+ params->interface ->cuda_graph_exec_destroy = CudaGraphExecDestroy;
1447+ params->interface ->cuda_graph_instantiate = CudaGraphInstantiate;
1448+ params->interface ->cuda_graph_get_nodes = CudaGraphGetNodes;
1449+ params->interface ->cuda_stream_capture_info = CudaStreamCaptureInfo;
1450+ params->interface ->get_parameter_setter_for_exec_graph =
1451+ GetParameterSettersForExecGraph;
1452+
12801453 params->interface ->get_compute_capability = GetComputeCapability;
12811454 params->interface ->get_device_properties = GetDeviceProperties;
12821455 params->interface ->get_runtime_version = GetRuntimeVersion;
0 commit comments