configurePlugin is called repeatedly for my BF16 SDPA plugin — how to run initialization graph only once? #4654

@KarlDe1

Describe the issue
I’m implementing a TensorRT plugin for SDPA (scaled dot-product attention) with BF16 inputs and outputs.
My goal is to build the cuDNN compute graph only once during initialization, so I placed the graph-construction logic inside configurePlugin().

However, I found that configurePlugin() is invoked every time the model executes, which forces the plugin to rebuild the graph repeatedly and causes significant overhead.

My expectation was that TensorRT would call configurePlugin() only at engine build time (the IBuilder phase), but in practice it is also called again during execution (the runtime phase).

More specifically:

  • Why is configurePlugin() called at runtime for each execution?

  • Is there a plugin API or recommended method to place one-time initialization logic?

  • Or is there a separate mechanism for caching prebuilt graph structures inside the plugin? (A sketch of the kind of guard I have in mind follows below.)
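
To make the last point concrete, this is the kind of guard I could add myself: a minimal sketch assuming hypothetical cached_q_shape_ / cached_k_shape_ members on the plugin (they are not part of any TensorRT API).

// Hypothetical guard at the top of configurePlugin(): rebuild the cuDNN graph
// only when the input shapes actually change. cached_q_shape_ and
// cached_k_shape_ are assumed std::vector<int64_t> plugin members, introduced
// here purely for illustration.
std::vector<int64_t> q_shape = get_shape_from_dims(inputDesc[INOUT_POS::q].desc.dims);
std::vector<int64_t> k_shape = get_shape_from_dims(inputDesc[INOUT_POS::k].desc.dims);
if (graph_ptr_ != nullptr && q_shape == cached_q_shape_ && k_shape == cached_k_shape_) {
  return;  // graph already built for these dimensions; skip the rebuild
}
cached_q_shape_ = q_shape;
cached_k_shape_ = k_shape;
// ...fall through to the existing graph construction and build...

This would hide the rebuild cost, but it feels like a workaround rather than the intended mechanism.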

System Environment:

  • cudnn_frontend version: v1.16.0
  • cudnn_backend version: 9.16.0.29_cuda12
  • GPU arch: RTX5060Ti / Thor U
  • cuda runtime version: 12.8
  • cuda driver version: 580.95.05
  • OS: Ubuntu 22.04

The relevant plugin code is below. (INOUT_POS, CUDA_CHECK, LFATAL, get_shape_from_dims, and create_cudnn_handle are small project helpers, omitted for brevity.)

#include <cmath>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cudnn_frontend.h>

#include <NvInfer.h>

namespace fe = cudnn_frontend;

// Tensor UIDs used to bind device pointers into the graph's variant pack.
// The values are arbitrary; they only need to be unique within the graph.
constexpr fe::graph::Tensor_attributes::uid_t Q_UID = 1;
constexpr fe::graph::Tensor_attributes::uid_t K_UID = 2;
constexpr fe::graph::Tensor_attributes::uid_t V_UID = 3;
constexpr fe::graph::Tensor_attributes::uid_t O_UID = 4;

std::shared_ptr<fe::graph::Graph> create_sdpa_forward_graph(int64_t const b,
                                                            int64_t const h_q,
                                                            int64_t const h_k,
                                                            int64_t const h_v,
                                                            int64_t const s_q,
                                                            int64_t const s_kv,
                                                            int64_t const d_qk,
                                                            int64_t const d_v,
                                                            float const attn_scale = 1.0f,
                                                            bool const causal_mask = true) {
  // Create a graph and set common global properties.
  auto graph = std::make_shared<fe::graph::Graph>();
  graph->set_io_data_type(fe::DataType_t::BFLOAT16)
      .set_intermediate_data_type(fe::DataType_t::FLOAT)
      .set_compute_data_type(fe::DataType_t::FLOAT);

  auto Q = graph->tensor(fe::graph::Tensor_attributes()
                             .set_name("Q")
                             .set_uid(Q_UID)
                             .set_dim({b, h_q, s_q, d_qk})
                             .set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1}));

  auto K = graph->tensor(fe::graph::Tensor_attributes()
                             .set_name("K")
                             .set_uid(K_UID)
                             .set_dim({b, h_k, s_kv, d_qk})
                             .set_stride({h_k * s_kv * d_qk, s_kv * d_qk, d_qk, 1}));

  auto V = graph->tensor(fe::graph::Tensor_attributes()
                             .set_name("V")
                             .set_uid(V_UID)
                             .set_dim({b, h_v, s_kv, d_v})
                             .set_stride({h_v * s_kv * d_v, s_kv * d_v, d_v, 1}));

  auto sdpa_options =
      fe::graph::SDPA_attributes().set_name("flash_attention").set_attn_scale(attn_scale);

  if (causal_mask) {
    sdpa_options.set_diagonal_alignment(fe::DiagonalAlignment_t::TOP_LEFT)
        .set_diagonal_band_right_bound(0);
  }

  auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options);

  O->set_output(true)
      .set_dim({b, h_q, s_q, d_v})
      .set_stride({h_q * s_q * d_v, s_q * d_v, d_v, 1})
      .set_uid(O_UID);

  return graph;
}
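
For completeness, here is a minimal standalone use of the builder above (the shape values are arbitrary; create_cudnn_handle() is the same helper used in configurePlugin() below):

// Hypothetical smoke test of create_sdpa_forward_graph(); build() validates
// and compiles the graph against the given cuDNN handle.
auto handle_ptr = create_cudnn_handle();
auto graph = create_sdpa_forward_graph(/*b=*/1, /*h_q=*/8, /*h_k=*/8, /*h_v=*/8,
                                       /*s_q=*/128, /*s_kv=*/128,
                                       /*d_qk=*/64, /*d_v=*/64,
                                       /*attn_scale=*/0.125f, /*causal_mask=*/true);
if (!graph->build(*handle_ptr, {fe::HeurMode_t::A}).is_good()) {
  LFATAL("standalone graph build failed.");
}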

void SdpaCudnn::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputDesc,
                                int32_t nbInputs,
                                const nvinfer1::DynamicPluginTensorDesc* out,
                                int32_t nbOutputs) noexcept {
  // [batch, head, seq, dim]
  std::vector<int64_t> q_shape = get_shape_from_dims(inputDesc[INOUT_POS::q].desc.dims);
  std::vector<int64_t> k_shape = get_shape_from_dims(inputDesc[INOUT_POS::k].desc.dims);
  std::vector<int64_t> v_shape = get_shape_from_dims(inputDesc[INOUT_POS::v].desc.dims);

  int64_t b = q_shape[0];     // batch size
  int64_t h_q = q_shape[1];   // number of query heads
  int64_t h_k = k_shape[1];   // number of key heads
  int64_t h_v = v_shape[1];   // number of value heads
  int64_t s_q = q_shape[2];   // q tensor is padded to this sequence length
  int64_t s_kv = k_shape[2];  // k and v tensors are padded to this sequence length
  int64_t d_qk = q_shape[3];  // per-head dimension of q and k
  int64_t d_v = v_shape[3];   // per-head dimension of v

  float attn_scale = 1.0f / std::sqrt(static_cast<float>(d_qk));
  cudnn_handle_ptr_ = create_cudnn_handle();  // note: a new handle is created on every call

  graph_ptr_ = create_sdpa_forward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v, attn_scale, true);
  if (!graph_ptr_->build(*cudnn_handle_ptr_, {fe::HeurMode_t::A}).is_good()) {
    LFATAL("graph is not good.");
  }

  // The output staging buffer holds BF16 values, so size it by __nv_bfloat16
  // (sizeof(half) happens to be the same 2 bytes, but was misleading). Note that
  // repeated configurePlugin() calls leak the previous allocation, since
  // o_tensor_ptr_ is overwritten without a cudaFree().
  o_tensor_sz_ = b * h_q * s_q * d_v * sizeof(__nv_bfloat16);
  CUDA_CHECK(cudaMalloc(&o_tensor_ptr_, o_tensor_sz_));
}
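
Related to the leak noted in the comment above, the teardown would presumably live in terminate(), the standard IPluginV2 hook. A sketch under that assumption:

// Sketch of a matching teardown in terminate(), releasing what
// configurePlugin() allocated. Assumes cudnn_handle_ptr_ is a shared_ptr
// whose deleter destroys the cuDNN handle.
void SdpaCudnn::terminate() noexcept {
  if (o_tensor_ptr_ != nullptr) {
    CUDA_CHECK(cudaFree(o_tensor_ptr_));
    o_tensor_ptr_ = nullptr;
  }
  cudnn_handle_ptr_.reset();
}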

int32_t SdpaCudnn::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                           const nvinfer1::PluginTensorDesc* outputDesc,
                           const void* const* inputs,
                           void* const* outputs,
                           void* workspace,
                           cudaStream_t stream) noexcept {
  if (cudnn_handle_ptr_ == nullptr || graph_ptr_ == nullptr) {
    LFATAL("CUDNN handle or graph is not initialized.");
  }

  const void* q_ptr = inputs[INOUT_POS::q];
  const void* k_ptr = inputs[INOUT_POS::k];
  const void* v_ptr = inputs[INOUT_POS::v];

  std::unordered_map<fe::graph::Tensor_attributes::uid_t, void*> variant_pack = {
      {Q_UID, const_cast<void*>(q_ptr)},
      {K_UID, const_cast<void*>(k_ptr)},
      {V_UID, const_cast<void*>(v_ptr)},
      {O_UID, o_tensor_ptr_}};

  auto handle = *cudnn_handle_ptr_;
  cudnnSetStream(handle, stream);

  if (!graph_ptr_->execute(handle, variant_pack, workspace).is_good()) {
    LFATAL("SDPA CUDNN graph execution failed.");
  }

  // The graph launch and the copy below run on the same stream, so stream
  // ordering already guarantees the copy sees the finished output; a host-side
  // cudaStreamSynchronize() here would only stall the pipeline.
  CUDA_CHECK(cudaMemcpyAsync(outputs[INOUT_POS::output],
                             o_tensor_ptr_,
                             o_tensor_sz_,
                             cudaMemcpyDeviceToDevice,
                             stream));
  return 0;
}
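
For reference, the workspace pointer that TensorRT passes into enqueue() is sized by the getWorkspaceSize() override; a sketch of what would match the code above (it assumes graph_ptr_ has already been built when TensorRT queries it):

// Sketch of the matching IPluginV2DynamicExt::getWorkspaceSize() override.
// If the graph has not been built yet, conservatively report zero.
size_t SdpaCudnn::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                                   int32_t nbInputs,
                                   const nvinfer1::PluginTensorDesc* outputs,
                                   int32_t nbOutputs) const noexcept {
  int64_t workspace_size = 0;
  if (graph_ptr_ != nullptr) {
    graph_ptr_->get_workspace_size(workspace_size);  // cudnn_frontend query
  }
  return static_cast<size_t>(workspace_size);
}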
