@@ -32,15 +32,17 @@ PrepackNode::PrepackNode(
3232 const ValueRef tref,
3333 const ValueRef packed,
3434 const vkapi::ParamsBindList& params,
35- const vkapi::SpecVarList& spec_vars)
35+ const vkapi::SpecVarList& spec_vars,
36+ const std::vector<PushConstantDataInfo>& push_constants)
3637 : shader_(shader),
3738 noop_shader_ (get_noop_shader(graph, packed)),
3839 global_workgroup_size_(global_workgroup_size),
3940 local_workgroup_size_(local_workgroup_size),
4041 tref_(tref),
4142 packed_(packed),
4243 params_(params),
43- spec_vars_(spec_vars) {
44+ spec_vars_(spec_vars),
45+ push_constants_(push_constants) {
4446 graph.update_descriptor_counts (shader, /* execute = */ false );
4547 graph.update_descriptor_counts (noop_shader_, /* execute = */ false );
4648}
@@ -75,10 +77,20 @@ void PrepackNode::encode(ComputeGraph* graph) {
7577
7678 std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock ();
7779
80+ std::array<uint8_t , kMaxPushConstantSize > push_constants_data;
81+ uint32_t push_constants_offset = 0 ;
82+
83+ for (const auto & push_constant : push_constants_) {
84+ push_constants_offset += push_constant.write (
85+ push_constants_data.data (),
86+ push_constants_offset,
87+ kMaxPushConstantSize );
88+ }
89+
7890 {
7991 vkapi::PipelineBarrier pipeline_barrier{};
8092 vkapi::DescriptorSet descriptor_set = context->get_descriptor_set (
81- shader_, local_workgroup_size_, spec_vars_, 0u );
93+ shader_, local_workgroup_size_, spec_vars_, push_constants_offset );
8294
8395 uint32_t idx = 0 ;
8496 bind_tensor_to_descriptor_set (
@@ -91,7 +103,12 @@ void PrepackNode::encode(ComputeGraph* graph) {
91103 bind_params_to_descriptor_set (params_, descriptor_set, idx);
92104
93105 context->register_shader_dispatch (
94- descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
106+ descriptor_set,
107+ pipeline_barrier,
108+ shader_,
109+ global_workgroup_size_,
110+ push_constants_data.data (),
111+ push_constants_offset);
95112 }
96113
97114 // Submit a compute shader that performs a no-op with the packed tensor in
0 commit comments