-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Description
Describe the issue
I’m implementing a TensorRT plugin for SDPA with BF16 input/output.
My goal is to build the compute graph only once during initialization, so I placed the graph-construction logic inside configurePlugin().
However, I found that configurePlugin() is invoked every time the model executes, which forces the plugin to rebuild the graph repeatedly and causes significant overhead.
My expectation was that TensorRT would call configurePlugin() only during engine build time (IBuilder phase).
But in practice, it is also called again during execution (runtime phase).
More specifically:
-
Why is configurePlugin() called at runtime for each execution?
-
Is there a plugin API or recommended method to place one-time initialization logic?
-
Or is there a separate mechanism for caching prebuilt graph structures inside the plugin?
System Environment:
- cudnn_frontend version: v1.16.0
- cudnn_backend version: 9.16.0.29_cuda12
- GPU arch: RTX5060Ti / Thor U
- cuda runtime version: 12.8
- cuda driver version: 580.95.05
- OS: ubuntu22.04
std::shared_ptr<fe::graph::Graph> create_sdpa_forward_graph(int64_t const b,
int64_t const h_q,
int64_t const h_k,
int64_t const h_v,
int64_t const s_q,
int64_t const s_kv,
int64_t const d_qk,
int64_t const d_v,
float const attn_scale = 1.0f,
bool const causal_mask = true) {
// Create a graph and set common global properties.
auto graph = std::make_shared<fe::graph::Graph>();
graph->set_io_data_type(fe::DataType_t::BFLOAT16)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_uid(Q_UID)
.set_dim({b, h_q, s_q, d_qk})
.set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1}));
auto K = graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_uid(K_UID)
.set_dim({b, h_k, s_kv, d_qk})
.set_stride({h_k * s_kv * d_qk, s_kv * d_qk, d_qk, 1}));
auto V = graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_uid(V_UID)
.set_dim({b, h_v, s_kv, d_v})
.set_stride({h_v * s_kv * d_v, s_kv * d_v, d_v, 1}));
auto sdpa_options =
fe::graph::SDPA_attributes().set_name("flash_attention").set_attn_scale(attn_scale);
if (causal_mask) {
sdpa_options.set_diagonal_alignment(cudnn_frontend::DiagonalAlignment_t::TOP_LEFT)
.set_diagonal_band_right_bound(0);
}
auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options);
O->set_output(true)
.set_dim({b, h_q, s_q, d_v})
.set_stride({h_q * s_q * d_v, s_q * d_v, d_v, 1})
.set_uid(O_UID);
return graph;
}
void SdpaCudnn::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputDesc,
int32_t nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int32_t nbOutputs) noexcept {
// [batch, head, seq, dim]
std::vector<int64_t> q_shape = get_shape_from_dims(inputDesc[INOUT_POS::q].desc.dims);
std::vector<int64_t> k_shape = get_shape_from_dims(inputDesc[INOUT_POS::k].desc.dims);
std::vector<int64_t> v_shape = get_shape_from_dims(inputDesc[INOUT_POS::v].desc.dims);
int64_t b = q_shape[0]; // batch size
int64_t h_q = q_shape[1]; // head dim
int64_t h_k = k_shape[1]; // head dim
int64_t h_v = v_shape[1]; // head dim
int64_t s_q = q_shape[2]; // q tensor is padded to this seq length
int64_t s_kv = k_shape[2]; // k and v tensor is padded to this seq length
int64_t d_qk = q_shape[3]; // hidden dim
int64_t d_v = v_shape[3]; // hidden dim
float attn_scale = 1 / sqrt(d_qk);
cudnn_handle_ptr_ = create_cudnn_handle();
graph_ptr_ = create_sdpa_forward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v, attn_scale, true);
if (!graph_ptr_->build(*cudnn_handle_ptr_, {fe::HeurMode_t::A}).is_good()) {
LFATAL("graph is not good.");
}
o_tensor_sz_ = b * h_q * s_q * d_v * sizeof(half);
CUDA_CHECK(cudaMalloc(&o_tensor_ptr_, o_tensor_sz_));
}
int32_t SdpaCudnn::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept {
if (cudnn_handle_ptr_ == nullptr || graph_ptr_ == nullptr) {
LFATAL("CUDNN handle or graph is not initialized.");
}
const void* q_ptr = inputs[INOUT_POS::q];
const void* k_ptr = inputs[INOUT_POS::k];
const void* v_ptr = inputs[INOUT_POS::v];
std::unordered_map<fe::graph::Tensor_attributes::uid_t, void*> variant_pack = {
{Q_UID, const_cast<void*>(q_ptr)},
{K_UID, const_cast<void*>(k_ptr)},
{V_UID, const_cast<void*>(v_ptr)},
{O_UID, o_tensor_ptr_}};
auto handle = *cudnn_handle_ptr_;
cudnnSetStream(handle, stream);
if (!graph_ptr_->execute(handle, variant_pack, workspace).is_good()) {
LFATAL("SDPA CUDNN graph execution failed.");
}
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaMemcpyAsync(outputs[INOUT_POS::output],
o_tensor_ptr_,
o_tensor_sz_,
cudaMemcpyDeviceToDevice,
stream));
return 0;
}