Skip to content

Commit b1a949b

Browse files
[Metax] support cudagraph on metax_gpu (PaddlePaddle#196) (PaddlePaddle#2264)
Co-authored-by: zhang-chenyi <[email protected]>
1 parent 17201b1 commit b1a949b

File tree

1 file changed

+173
-0
lines changed

1 file changed

+173
-0
lines changed

backends/metax_gpu/runtime/runtime.cc

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "paddle/fluid/platform/profiler/cuda_tracer.h"
4040
#include "paddle/fluid/platform/profiler/cupti_data_process.h"
4141
#include "paddle/phi/api/profiler/trace_event_collector.h"
42+
#include "paddle/phi/backends/custom/cuda_graph.h"
4243
#include "paddle/phi/backends/device_base.h"
4344
#include "paddle/phi/backends/device_ext.h"
4445
#include "paddle/phi/backends/dynload/cublasLt.h"
@@ -671,6 +672,11 @@ C_Status Allocate(const C_Device device, void **ptr, size_t size) {
671672
return C_ERROR;
672673
}
673674

675+
// cudagraph temp code
676+
// cudaStream_t stream;
677+
// cudaStreamCreate(&stream);
678+
// err = cudaMallocAsync(ptr, size,cudaStreamDefault);
679+
674680
err = cudaMalloc(ptr, size);
675681
if (err != cudaSuccess) {
676682
*ptr = NULL;
@@ -1268,6 +1274,162 @@ C_Status IsDNNSupported(const C_Device device, bool *supported) {
12681274
return C_SUCCESS;
12691275
}
12701276

1277+
// Begins stream capture on `stream` using capture mode `mode`.
//
// `device` is accepted for interface compatibility and is not read here.
// Returns C_SUCCESS when cudaStreamBeginCapture reports cudaSuccess and
// C_ERROR for any other result.
C_Status CudaStreamBeginCapture(const C_Device device,
                                C_Stream stream,
                                C_StreamCaptureMode mode) {
  const cudaError_t status = cudaStreamBeginCapture(
      cudaStream_t(stream), cudaStreamCaptureMode(mode));
  return status == cudaSuccess ? C_SUCCESS : C_ERROR;
}
1285+
1286+
// Ends stream capture on `stream` and writes the captured graph handle
// through `pGraph`.
//
// `device` is accepted for interface compatibility and is not read here.
// Returns C_SUCCESS when cudaStreamEndCapture reports cudaSuccess and
// C_ERROR for any other result.
C_Status CudaStreamEndCaptrue(const C_Device device,
                              C_Stream stream,
                              C_CudaGraph *pGraph) {
  cudaGraph_t *captured_graph = reinterpret_cast<cudaGraph_t *>(pGraph);
  const cudaError_t status =
      cudaStreamEndCapture(cudaStream_t(stream), captured_graph);
  return status == cudaSuccess ? C_SUCCESS : C_ERROR;
}
1295+
1296+
// Queries the nodes of `graph`.
//
// Parameters are forwarded to cudaGraphGetNodes:
//   pNode    - optional output array for node handles. Pass nullptr to query
//              only the node count.
//   numNodes - required. When `pNode` is nullptr it receives the node count;
//              otherwise it is passed through unchanged as the runtime's
//              in/out count argument.
//
// Returns C_ERROR when `numNodes` is nullptr or the runtime call fails,
// C_SUCCESS otherwise.
C_Status CudaGraphGetNodes(C_CudaGraph graph,
                           C_CudaGraphNode *pNode,
                           size_t *numNodes) {
  if (numNodes == nullptr) {
    return C_ERROR;
  }
  // Forward the caller's buffer instead of always passing nullptr, so that a
  // caller supplying storage actually gets the node handles back.
  const cudaError_t status =
      cudaGraphGetNodes(cudaGraph_t(graph),
                        reinterpret_cast<cudaGraphNode_t *>(pNode),
                        numNodes);
  return status == cudaSuccess ? C_SUCCESS : C_ERROR;
}
1303+
1304+
// Launches the executable graph `exec` on `stream`.
//
// `device` is accepted for interface compatibility and is not read here.
// Returns C_SUCCESS when cudaGraphLaunch reports cudaSuccess and C_ERROR for
// any other result.
C_Status CudaGraphLaunch(const C_Device device,
                         C_GraphExec exec,
                         C_Stream stream) {
  const bool launched =
      cudaGraphLaunch(cudaGraphExec_t(exec), cudaStream_t(stream)) ==
      cudaSuccess;
  return launched ? C_SUCCESS : C_ERROR;
}
1314+
1315+
// Destroys `graph` via cudaGraphDestroy.
// Returns C_SUCCESS on cudaSuccess and C_ERROR for any other result.
C_Status CudaGraphDestroy(C_CudaGraph graph) {
  const cudaError_t status = cudaGraphDestroy(cudaGraph_t(graph));
  return status == cudaSuccess ? C_SUCCESS : C_ERROR;
}
1319+
1320+
// Destroys the executable graph `exec` via cudaGraphExecDestroy.
// Returns C_SUCCESS on cudaSuccess and C_ERROR for any other result.
C_Status CudaGraphExecDestroy(C_GraphExec exec) {
  const cudaError_t status = cudaGraphExecDestroy(cudaGraphExec_t(exec));
  return status == cudaSuccess ? C_SUCCESS : C_ERROR;
}
1325+
1326+
// Instantiates the graph pointed to by `pGraph` into an executable graph and
// writes the resulting handle through `pExec`.
//
// Instantiation uses cudaGraphInstantiateWithFlags with
// cudaGraphInstantiateFlagAutoFreeOnLaunch. That call takes no diagnostic
// outputs, so `pErrorNode`, `pLogBuffer` and `bufferSize` are accepted for
// interface compatibility but are neither read nor written.
//
// Returns C_ERROR when `pExec` or `pGraph` is nullptr or when the runtime
// call fails, C_SUCCESS otherwise.
C_Status CudaGraphInstantiate(C_GraphExec *pExec,
                              C_CudaGraph *pGraph,
                              void **pErrorNode,
                              char *pLogBuffer,
                              size_t bufferSize) {
  // `pGraph` is dereferenced below; reject null pointers instead of crashing.
  if (pExec == nullptr || pGraph == nullptr) {
    return C_ERROR;
  }
  const cudaError_t status = cudaGraphInstantiateWithFlags(
      reinterpret_cast<cudaGraphExec_t *>(pExec),
      *(reinterpret_cast<cudaGraph_t *>(pGraph)),
      cudaGraphInstantiateFlagAutoFreeOnLaunch);
  return status == cudaSuccess ? C_SUCCESS : C_ERROR;
}
1338+
1339+
// Reports the capture status and capture id of `stream`.
//
// Only `captureStatus_out` and `id_out` are forwarded to
// cudaStreamGetCaptureInfo. `device`, `graph_out`, `dependencies_out`,
// `edgeData_out` and `numDependencies_out` are accepted for interface
// compatibility and are neither read nor written here.
//
// Returns C_SUCCESS on cudaSuccess and C_ERROR for any other result.
C_Status CudaStreamCaptureInfo(const C_Device device,
                               C_Stream stream,
                               C_StreamCaptureStatus *captureStatus_out,
                               unsigned long long *id_out,  // NOLINT
                               C_CudaGraph *graph_out,
                               C_CudaGraphNode *dependencies_out,
                               void **edgeData_out,
                               size_t *numDependencies_out) {
  cudaStreamCaptureStatus *status_out =
      reinterpret_cast<cudaStreamCaptureStatus *>(captureStatus_out);
  const cudaError_t rc =
      cudaStreamGetCaptureInfo(cudaStream_t(stream), status_out, id_out);
  if (rc == cudaSuccess) {
    return C_SUCCESS;
  }
  return C_ERROR;
}
1354+
1355+
// Builds the list of per-kernel "parameter setter" hooks for `graph` and
// stores it in `c_hook`.
//
// For every kernel node in the graph whose function is registered in
// CUDAGraphNodeLauncher::Instance().parameterSetters, the launch id is read
// from kernel parameter 0 and the matching setter is looked up. Each match
// produces:
//   - a heap-allocated context (setter, node, kernel node params), and
//   - a hook that re-runs the setter on the saved params and applies them to
//     an executable graph via cudaGraphExecKernelNodeSetParams.
//
// On success `c_hook->size` is the number of hooks, and `c_hook->hooks` /
// `c_hook->user_data` are malloc'd arrays of that length (both nullptr when
// no hooks were produced). This function does not release the arrays or the
// contexts they point to.
// NOTE(review): presumably the caller owns and frees them - confirm.
//
// Returns C_ERROR if `c_hook` is nullptr or the output arrays cannot be
// allocated. Runtime failures raise through PADDLE_ENFORCE_GPU_SUCCESS, and a
// registered kernel with an unknown launch id raises InvalidArgument. Contexts
// created before any such failure are released.
C_Status GetParameterSettersForExecGraph(C_CudaGraph graph,
                                         C_GraphHookManager *c_hook) {
  using parameterSetter_t =
      std::function<void(phi::backends::gpu::gpuKernelParams &)>;
  struct SetKernelParamsCtx {
    parameterSetter_t setter_func;
    cudaGraphNode_t node;
    cudaKernelNodeParams params;
  };

  if (c_hook == nullptr) {
    return C_ERROR;
  }
  // Put the output in a well-defined empty state up front so the caller never
  // sees stale pointers when no hooks are produced or an error is returned.
  c_hook->size = 0;
  c_hook->hooks = nullptr;
  c_hook->user_data = nullptr;

  size_t num_nodes = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaGraphGetNodes(cudaGraph_t(graph), nullptr, &num_nodes));
  std::vector<cudaGraphNode_t> nodes(num_nodes);
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaGraphGetNodes(cudaGraph_t(graph), nodes.data(), &num_nodes));

  std::vector<SetKernelParamsCtx *> cts_all;
  std::vector<C_GraphExecuterSetter> hooks;
  // Reserve once so the emplace_back calls below cannot throw after a context
  // has been allocated.
  cts_all.reserve(nodes.size());
  hooks.reserve(nodes.size());

  auto release_contexts = [&cts_all]() {
    for (SetKernelParamsCtx *ctx : cts_all) {
      delete ctx;
    }
    cts_all.clear();
  };

  try {
    const auto &registered_setters =
        phi::backends::gpu::CUDAGraphNodeLauncher::Instance().parameterSetters;
    for (auto node : nodes) {
      cudaGraphNodeType node_type;
      // Checked in all build modes (the previous assert vanished under NDEBUG
      // and compared against the driver-API CUDA_SUCCESS constant).
      PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphNodeGetType(node, &node_type));
      if (node_type != cudaGraphNodeTypeKernel) {
        continue;
      }

      cudaKernelNodeParams params = {};
      PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphKernelNodeGetParams(node, &params));
      phi::backends::gpu::gpuKernelParams kernel_params(params.kernelParams);

      auto kernel =
          registered_setters.find(static_cast<cudaFunction_t>(params.func));
      if (kernel == registered_setters.end()) {
        continue;  // No setter registered for this kernel function.
      }

      // Bind by reference: the per-kernel map does not need to be copied.
      const auto &launch_sequence = kernel->second;
      unsigned int id = kernel_params.As<int>(0);
      auto parameter_setter = launch_sequence.find(id);
      if (parameter_setter == launch_sequence.end()) {
        PADDLE_THROW(common::errors::InvalidArgument(
            "Error: does not find launch id"));
      }

      SetKernelParamsCtx *ctx = new SetKernelParamsCtx;
      ctx->node = node;
      ctx->params = params;
      ctx->setter_func = parameter_setter->second;
      cts_all.emplace_back(ctx);
      hooks.emplace_back([](C_GraphExec exec_graph, void *userdate) {
        SetKernelParamsCtx *tmp =
            reinterpret_cast<SetKernelParamsCtx *>(userdate);
        phi::backends::gpu::gpuKernelParams hook_params(
            tmp->params.kernelParams);
        tmp->setter_func(hook_params);
        // Surface a failed update instead of silently replaying stale
        // kernel parameters.
        PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphExecKernelNodeSetParams(
            reinterpret_cast<cudaGraphExec_t>(exec_graph),
            tmp->node,
            &tmp->params));
      });
    }
  } catch (...) {
    // Do not leak contexts created before the failure.
    release_contexts();
    throw;
  }

  const size_t hook_count = cts_all.size();
  if (hook_count == 0) {
    return C_SUCCESS;
  }

  C_GraphExecuterSetter *hook_array = reinterpret_cast<C_GraphExecuterSetter *>(
      malloc(sizeof(C_GraphExecuterSetter) * hook_count));
  void **user_data_array =
      reinterpret_cast<void **>(malloc(sizeof(void *) * hook_count));
  if (hook_array == nullptr || user_data_array == nullptr) {
    free(hook_array);
    free(user_data_array);
    release_contexts();
    return C_ERROR;
  }

  std::memcpy(
      hook_array, hooks.data(), hook_count * sizeof(C_GraphExecuterSetter));
  std::memcpy(user_data_array, cts_all.data(), hook_count * sizeof(void *));

  c_hook->size = hook_count;
  c_hook->hooks = hook_array;
  c_hook->user_data = user_data_array;
  return C_SUCCESS;
}
1432+
12711433
void InitPlugin(CustomRuntimeParams *params) {
12721434
PADDLE_CUSTOM_RUNTIME_CHECK_VERSION(params);
12731435
params->device_type = const_cast<char *>(DeviceType);
@@ -1277,6 +1439,17 @@ void InitPlugin(CustomRuntimeParams *params) {
12771439
0,
12781440
sizeof(C_DeviceInterface));
12791441

1442+
params->interface->cuda_stream_begin_capture = CudaStreamBeginCapture;
1443+
params->interface->cuda_stream_end_captrue = CudaStreamEndCaptrue;
1444+
params->interface->cuda_graph_launch = CudaGraphLaunch;
1445+
params->interface->cuda_graph_destroy = CudaGraphDestroy;
1446+
params->interface->cuda_graph_exec_destroy = CudaGraphExecDestroy;
1447+
params->interface->cuda_graph_instantiate = CudaGraphInstantiate;
1448+
params->interface->cuda_graph_get_nodes = CudaGraphGetNodes;
1449+
params->interface->cuda_stream_capture_info = CudaStreamCaptureInfo;
1450+
params->interface->get_parameter_setter_for_exec_graph =
1451+
GetParameterSettersForExecGraph;
1452+
12801453
params->interface->get_compute_capability = GetComputeCapability;
12811454
params->interface->get_device_properties = GetDeviceProperties;
12821455
params->interface->get_runtime_version = GetRuntimeVersion;

0 commit comments

Comments
 (0)