@@ -3297,3 +3297,140 @@ TEST(VulkanComputeGraphOpsTest, test_to_copy) {
32973297    test_to_copy ();
32983298  }
32993299}
3300+ 
3301+ vkapi::ShaderInfo pick_dynamic_dispatch_shader (
3302+     ComputeGraph* graph,
3303+     const  std::vector<ArgGroup>& args,
3304+     const  std::vector<ValueRef>& additional_args) {
3305+   const  ValueRef mat1 = args[1 ].refs [0 ];
3306+ 
3307+   std::string kernel_name = " dynamic_dispatch_test" 
3308+   if  (graph->size_at <int32_t >(-2 , mat1) == 1 ) {
3309+     kernel_name += " _var1" 
3310+   } else  {
3311+     kernel_name += " _var2" 
3312+   }
3313+   return  VK_KERNEL_FROM_STR (kernel_name);
3314+ }
3315+ 
3316+ utils::uvec3 pick_dynamic_dispatch_global_wg_size (
3317+     ComputeGraph* graph,
3318+     const  std::vector<ArgGroup>& args,
3319+     const  std::vector<ValueRef>& additional_args) {
3320+   const  ValueRef out = args[0 ].refs [0 ];
3321+ 
3322+   return  graph->logical_limits_of (out);
3323+ }
3324+ 
3325+ utils::uvec3 pick_dynamic_dispatch_local_wg_size (
3326+     ComputeGraph* graph,
3327+     const  std::vector<ArgGroup>& args,
3328+     const  std::vector<ValueRef>& additional_args) {
3329+   return  {64 , 1 , 1 };
3330+ }
3331+ 
3332+ void  resize_dynamic_dispatch_node (
3333+     ComputeGraph* graph,
3334+     const  std::vector<ArgGroup>& args,
3335+     const  std::vector<ValueRef>& additional_args) {
3336+   const  ValueRef out = args[0 ].refs [0 ];
3337+   const  ValueRef mat1 = args[1 ].refs [0 ];
3338+ 
3339+   std::vector<int64_t > out_sizes = graph->sizes_of (mat1);
3340+   out_sizes.at (out_sizes.size () - 2 ) = 1 ;
3341+ 
3342+   graph->get_tensor (out)->virtual_resize (out_sizes);
3343+ }
3344+ 
3345+ void  add_dynamic_dispatch_test_node (
3346+     ComputeGraph& graph,
3347+     const  ValueRef mat1,
3348+     const  ValueRef mat2,
3349+     const  ValueRef out) {
3350+   graph.execute_nodes ().emplace_back (new  DynamicDispatchNode (
3351+       graph,
3352+       pick_dynamic_dispatch_shader,
3353+       pick_dynamic_dispatch_global_wg_size,
3354+       pick_dynamic_dispatch_local_wg_size,
3355+       //  Inputs and Outputs
3356+       {{out, vkapi::kWrite }, {{mat1, mat2}, vkapi::kRead }},
3357+       //  Shader params buffers
3358+       {},
3359+       //  Push Constants
3360+       {graph.sizes_pc_of (out),
3361+        graph.sizes_pc_of (mat1),
3362+        graph.sizes_pc_of (mat2)},
3363+       //  Specialization constants
3364+       {},
3365+       //  Resize Logic
3366+       {},
3367+       resize_dynamic_dispatch_node));
3368+ }
3369+ 
3370+ vkcompute::ComputeGraph build_dynamic_dispatch_test_graph (int  M, int  N) {
3371+   using  namespace  vkcompute ; 
3372+   GraphConfig config;
3373+   ComputeGraph graph (config);
3374+ 
3375+   vkapi::ScalarType dtype = vkapi::kFloat ;
3376+   utils::StorageType in_out_stype = utils::kTexture3D ;
3377+   utils::GPUMemoryLayout memory_layout = utils::kWidthPacked ;
3378+ 
3379+   std::vector<int64_t > mat1_size = {M, N};
3380+   std::vector<int64_t > mat2_size = {M, N};
3381+   std::vector<int64_t > out_size = {1 , N};
3382+ 
3383+   IOValueRef mat1 =
3384+       graph.add_input_tensor (mat1_size, dtype, in_out_stype, memory_layout);
3385+   IOValueRef mat2{};
3386+ 
3387+   mat2.value  = graph.add_tensor (mat2_size, dtype, in_out_stype, memory_layout);
3388+   mat2.staging  = graph.set_input_tensor (mat2.value );
3389+ 
3390+   IOValueRef out;
3391+   out.value  = graph.add_tensor (out_size, dtype, in_out_stype, memory_layout);
3392+ 
3393+   add_dynamic_dispatch_test_node (graph, mat1, mat2, out);
3394+ 
3395+   out.staging  = graph.set_output_tensor (out.value );
3396+ 
3397+   return  graph;
3398+ }
3399+ 
3400+ void  test_dynamic_dispatch (int  M, int  N) {
3401+   ComputeGraph graph = build_dynamic_dispatch_test_graph (M, N);
3402+ 
3403+   graph.prepare ();
3404+   graph.encode_prepack ();
3405+   graph.prepack ();
3406+   graph.encode_execute ();
3407+ 
3408+   for  (int  i = 1 ; i < 4 ; i++) {
3409+     float  val_mat1 = i;
3410+     float  val_mat2 = i + 1 ;
3411+     //  5.3 is a hardcoded offset in the compute shader
3412+     float  val_out = M * (val_mat1 * val_mat2) + 5.5 ;
3413+     execute_graph_and_check_output (graph, {val_mat1, val_mat2}, {val_out});
3414+   }
3415+ 
3416+   //  Switch to GEMV mode
3417+   int  new_N = N / 2 ;
3418+   std::vector<int64_t > new_mat1_size = {1 , new_N};
3419+   std::vector<int64_t > new_mat2_size = {1 , new_N};
3420+   graph.resize_input (0 , new_mat1_size);
3421+   graph.resize_input (1 , new_mat2_size);
3422+   graph.propagate_resize ();
3423+ 
3424+   graph.encode_execute ();
3425+ 
3426+   for  (int  i = 1 ; i < 4 ; i++) {
3427+     float  val_mat1 = i;
3428+     float  val_mat2 = i + 1 ;
3429+     float  val_out = (val_mat1 * val_mat2) + 2.25 ;
3430+     execute_graph_and_check_output (graph, {val_mat1, val_mat2}, {val_out});
3431+   }
3432+ }
3433+ 
3434+ TEST (VulkanComputeGraphOpsTest, test_dynamic_dispatch_graph) {
3435+   test_dynamic_dispatch (128 , 128 );
3436+ }
0 commit comments