@@ -3297,3 +3297,140 @@ TEST(VulkanComputeGraphOpsTest, test_to_copy) {
32973297 test_to_copy ();
32983298 }
32993299}
3300+
3301+ vkapi::ShaderInfo pick_dynamic_dispatch_shader (
3302+ ComputeGraph* graph,
3303+ const std::vector<ArgGroup>& args,
3304+ const std::vector<ValueRef>& additional_args) {
3305+ const ValueRef mat1 = args[1 ].refs [0 ];
3306+
3307+ std::string kernel_name = " dynamic_dispatch_test" ;
3308+ if (graph->size_at <int32_t >(-2 , mat1) == 1 ) {
3309+ kernel_name += " _var1" ;
3310+ } else {
3311+ kernel_name += " _var2" ;
3312+ }
3313+ return VK_KERNEL_FROM_STR (kernel_name);
3314+ }
3315+
3316+ utils::uvec3 pick_dynamic_dispatch_global_wg_size (
3317+ ComputeGraph* graph,
3318+ const std::vector<ArgGroup>& args,
3319+ const std::vector<ValueRef>& additional_args) {
3320+ const ValueRef out = args[0 ].refs [0 ];
3321+
3322+ return graph->logical_limits_of (out);
3323+ }
3324+
3325+ utils::uvec3 pick_dynamic_dispatch_local_wg_size (
3326+ ComputeGraph* graph,
3327+ const std::vector<ArgGroup>& args,
3328+ const std::vector<ValueRef>& additional_args) {
3329+ return {64 , 1 , 1 };
3330+ }
3331+
3332+ void resize_dynamic_dispatch_node (
3333+ ComputeGraph* graph,
3334+ const std::vector<ArgGroup>& args,
3335+ const std::vector<ValueRef>& additional_args) {
3336+ const ValueRef out = args[0 ].refs [0 ];
3337+ const ValueRef mat1 = args[1 ].refs [0 ];
3338+
3339+ std::vector<int64_t > out_sizes = graph->sizes_of (mat1);
3340+ out_sizes.at (out_sizes.size () - 2 ) = 1 ;
3341+
3342+ graph->get_tensor (out)->virtual_resize (out_sizes);
3343+ }
3344+
3345+ void add_dynamic_dispatch_test_node (
3346+ ComputeGraph& graph,
3347+ const ValueRef mat1,
3348+ const ValueRef mat2,
3349+ const ValueRef out) {
3350+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
3351+ graph,
3352+ pick_dynamic_dispatch_shader,
3353+ pick_dynamic_dispatch_global_wg_size,
3354+ pick_dynamic_dispatch_local_wg_size,
3355+ // Inputs and Outputs
3356+ {{out, vkapi::kWrite }, {{mat1, mat2}, vkapi::kRead }},
3357+ // Shader params buffers
3358+ {},
3359+ // Push Constants
3360+ {graph.sizes_pc_of (out),
3361+ graph.sizes_pc_of (mat1),
3362+ graph.sizes_pc_of (mat2)},
3363+ // Specialization constants
3364+ {},
3365+ // Resize Logic
3366+ {},
3367+ resize_dynamic_dispatch_node));
3368+ }
3369+
3370+ vkcompute::ComputeGraph build_dynamic_dispatch_test_graph (int M, int N) {
3371+ using namespace vkcompute ;
3372+ GraphConfig config;
3373+ ComputeGraph graph (config);
3374+
3375+ vkapi::ScalarType dtype = vkapi::kFloat ;
3376+ utils::StorageType in_out_stype = utils::kTexture3D ;
3377+ utils::GPUMemoryLayout memory_layout = utils::kWidthPacked ;
3378+
3379+ std::vector<int64_t > mat1_size = {M, N};
3380+ std::vector<int64_t > mat2_size = {M, N};
3381+ std::vector<int64_t > out_size = {1 , N};
3382+
3383+ IOValueRef mat1 =
3384+ graph.add_input_tensor (mat1_size, dtype, in_out_stype, memory_layout);
3385+ IOValueRef mat2{};
3386+
3387+ mat2.value = graph.add_tensor (mat2_size, dtype, in_out_stype, memory_layout);
3388+ mat2.staging = graph.set_input_tensor (mat2.value );
3389+
3390+ IOValueRef out;
3391+ out.value = graph.add_tensor (out_size, dtype, in_out_stype, memory_layout);
3392+
3393+ add_dynamic_dispatch_test_node (graph, mat1, mat2, out);
3394+
3395+ out.staging = graph.set_output_tensor (out.value );
3396+
3397+ return graph;
3398+ }
3399+
3400+ void test_dynamic_dispatch (int M, int N) {
3401+ ComputeGraph graph = build_dynamic_dispatch_test_graph (M, N);
3402+
3403+ graph.prepare ();
3404+ graph.encode_prepack ();
3405+ graph.prepack ();
3406+ graph.encode_execute ();
3407+
3408+ for (int i = 1 ; i < 4 ; i++) {
3409+ float val_mat1 = i;
3410+ float val_mat2 = i + 1 ;
3411+ // 5.3 is a hardcoded offset in the compute shader
3412+ float val_out = M * (val_mat1 * val_mat2) + 5.5 ;
3413+ execute_graph_and_check_output (graph, {val_mat1, val_mat2}, {val_out});
3414+ }
3415+
3416+ // Switch to GEMV mode
3417+ int new_N = N / 2 ;
3418+ std::vector<int64_t > new_mat1_size = {1 , new_N};
3419+ std::vector<int64_t > new_mat2_size = {1 , new_N};
3420+ graph.resize_input (0 , new_mat1_size);
3421+ graph.resize_input (1 , new_mat2_size);
3422+ graph.propagate_resize ();
3423+
3424+ graph.encode_execute ();
3425+
3426+ for (int i = 1 ; i < 4 ; i++) {
3427+ float val_mat1 = i;
3428+ float val_mat2 = i + 1 ;
3429+ float val_out = (val_mat1 * val_mat2) + 2.25 ;
3430+ execute_graph_and_check_output (graph, {val_mat1, val_mat2}, {val_out});
3431+ }
3432+ }
3433+
3434+ TEST (VulkanComputeGraphOpsTest, test_dynamic_dispatch_graph) {
3435+ test_dynamic_dispatch (128 , 128 );
3436+ }
0 commit comments