Skip to content

Commit a3cc24d

Browse files
authored
Enable cuda graph for LLMs for NvTensorRtRtx EP (microsoft#1645)
- Enable cuda graph for LLMs for NvTensorRtRtx EP
- Change the flag name from nv_cuda_graph_enable to enable_cuda_graph
1 parent 5c15fe2 commit a3cc24d

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/models/model.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -584,7 +584,7 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
584584
bool is_multi_profile_enabled = IsMultiProfileEnabled(config.model.decoder.session_options);
585585
ConfigureNvTensorRtRTxProfile(config, session_options, is_multi_profile_enabled);
586586
if (IsGraphCaptureEnabled(config.model.decoder.session_options)) {
587-
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.nv_cuda_graph_enable", "1");
587+
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.enable_cuda_graph", "1");
588588
}
589589
p_device = GetDeviceInterface(DeviceType::NvTensorRtRtx);
590590
}

src/python/py/models/builder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ def __init__(
9494
},
9595
"dml": {},
9696
"webgpu": {},
97-
"NvTensorRtRtx": {},
97+
"NvTensorRtRtx": {"enable_cuda_graph": "1"} if extra_options.get("enable_cuda_graph", False) else {},
9898
}
9999

100100
# Map input names to their types and shapes

0 commit comments

Comments
 (0)