1616
1717namespace vkcompute {
1818
// Selects which upsample shader add_upsample_nearest2d_node dispatches:
// NEAREST -> upsample_nearest2d, BILINEAR -> upsample_bilinear2d.
19+ enum class UpsampleMode : int { NEAREST, BILINEAR };
20+
// Resize callback that add_upsample_nearest2d_node hands to the DispatchNode
// it creates. Most of the body is elided by the diff hunk header below; the
// only visible statement resizes `out` to `out_sizes` via virtual_resize.
1921void resize_upsample_nearest2d_node (
2022 ComputeGraph* graph,
2123 const std::vector<ArgGroup>& args,
@@ -39,19 +41,12 @@ void resize_upsample_nearest2d_node(
3941 out->virtual_resize (out_sizes);
4042}
4143
42- // ExecuTorch-Vulkan framework to add node
43- // Args:
44- // in: will be converted from NCHW input tensor to 3D ARGB representation in
45- // openGL (via ExecuTorch) output_sizes: optional 2D array of targetting
46- // output size of H and W dimensions. >= input sizes;
47-
48- // will be computed if only given the scale_factors.
49- // scale_factors: optional 2D array of scale factors for H and W dimensions.
50- // Will be computed if only given the output_sizes.
// Appends an upsample DispatchNode to `graph`. `mode` picks the shader
// (upsample_nearest2d or upsample_bilinear2d). `output_sizes` and
// `scale_factors` are alternatives; the visible VK_THROW message states that
// only one of them may be given.
// NOTE(review): `align_corners` is not read in any line visible in this diff;
// confirm whether it is meant to gate the (in - 1) / (out - 1) scales below.
5144void add_upsample_nearest2d_node (
5245 ComputeGraph& graph,
46+ const UpsampleMode mode,
5347 const ValueRef in,
5448 const ValueRef output_sizes,
49+ const ValueRef align_corners,
5550 const ValueRef scale_factors,
5651 const ValueRef out) {
// Case where neither argument is supplied; its handling is elided by the hunk
// header that follows.
5752 if (graph.val_is_none (output_sizes) && graph.val_is_none (scale_factors)) {
@@ -62,37 +57,51 @@ void add_upsample_nearest2d_node(
6257 VK_THROW (
6358 " Invalid input, must provide ONLY one of output_sizes or scale_factors" );
6459 }
// Logical extents of the input and output; component 0 is used as the width
// and component 1 as the height in the code below.
60+ utils::uvec3 in_limits = graph.logical_limits_of (in);
61+ utils::uvec3 out_limits = graph.logical_limits_of (out);
6562
66- vTensorPtr t_in = graph. get_tensor (in) ;
67- utils::uvec3 input_sizes = t_in-> logical_limits () ;
63+ uint32_t out_width = out_limits[ 0u ] ;
64+ uint32_t out_height = out_limits[ 1u ] ;
6865
69- utils::ivec2 input_size = {
70- utils::safe_downcast< int32_t >(input_sizes[ 0 ]),
71- utils::safe_downcast< int32_t >(input_sizes[ 1 ])};
72- utils::vec2 rev_scales = {
73- utils::safe_downcast< float >( 1.0 ), utils::safe_downcast< float >( 1.0 )} ;
// Default per-axis ratio of input extent to output extent.
66+ float scale_factor_x = float (in_limits[ 0u ]) / float (out_width);
67+ float scale_factor_y = float (in_limits[ 1u ]) / float (out_height);
68+
// NOTE(review): these initial values are overwritten unconditionally before
// recip_scales is built, so they never reach the shader.
69+ float recip_scale_factor_x = 1 . 0f / scale_factor_x;
70+ float recip_scale_factor_y = 1 . 0f / scale_factor_y ;
7471
75- // Reverse scale factors that pre-computed before GLSL.
// Explicit output size: element 1 is taken as the width and element 0 as the
// height, and both must equal the output's logical limits.
7672 if (!graph.val_is_none (output_sizes)) {
77- auto output_size_ref = graph.get_int_list (output_sizes);
78- rev_scales = {
79- utils::safe_downcast< float >(
80- ( float )input_size[ 0 ] / output_size_ref-> at ( 1 )),
81- utils::safe_downcast< float >(
82- ( float )input_size[ 1 ] / output_size_ref-> at ( 0 ))} ;
73+ IntListPtr output_size_ref = graph.get_int_list (output_sizes);
74+ out_width = output_size_ref-> at ( 1 );
75+ out_height = output_size_ref-> at ( 0 );
76+
77+ VK_CHECK_COND (out_width == out_limits[ 0u ]);
78+ VK_CHECK_COND (out_height == out_limits[ 1u ]) ;
8379
8480 } else {
85- auto scales = graph.get_double_list (scale_factors);
86- rev_scales = {
87- utils::safe_downcast<float >(1.0 / scales->at (1 )),
88- utils::safe_downcast<float >(1.0 / scales->at (0 ))};
// Explicit scales: element 1 applies to the width, element 0 to the height.
// NOTE(review): the list holds doubles that are narrowed to float here, and
// the checks below use exact == on a float product, so a non-integral scale
// could fail them through rounding alone - confirm this is intended.
81+ DoubleListPtr scales = graph.get_double_list (scale_factors);
82+ scale_factor_x = scales->at (1 );
83+ scale_factor_y = scales->at (0 );
84+
85+ VK_CHECK_COND (in_limits[0u ] * scale_factor_x == out_width);
86+ VK_CHECK_COND (in_limits[1u ] * scale_factor_y == out_height);
8987 }
9088
91- vTensorPtr t_out = graph.get_tensor (out);
// NOTE(review): applied for every mode and without consulting align_corners.
// An output extent of 1 makes the divisor 0, and out_width/out_height are
// uint32_t, so `out - 1` wraps when an extent is 0 (in_limits presumably
// behaves the same way). Confirm whether a guard or an align_corners check
// belongs here.
89+ recip_scale_factor_x = float (in_limits[0u ] - 1 ) / float (out_width - 1 );
90+ recip_scale_factor_y = float (in_limits[1u ] - 1 ) / float (out_height - 1 );
91+
92+ utils::vec2 recip_scales = {recip_scale_factor_x, recip_scale_factor_y};
9293
93- std::string kernel_name ( " upsample_nearest2d " ) ;
// Pick the shader name by mode, then append the suffix for the output dtype.
94+ std::string kernel_name;
9495 kernel_name.reserve (kShaderNameReserve );
95- add_dtype_suffix (kernel_name, *t_out);
96+ switch (mode) {
97+ case UpsampleMode::NEAREST:
98+ kernel_name = " upsample_nearest2d" ;
99+ break ;
100+ case UpsampleMode::BILINEAR:
101+ kernel_name = " upsample_bilinear2d" ;
102+ break ;
103+ }
104+ add_dtype_suffix (kernel_name, graph.dtype_of (out));
96105
97106 graph.execute_nodes ().emplace_back (new DispatchNode (
98107 graph,
@@ -103,21 +112,44 @@ void add_upsample_nearest2d_node(
103112 {{out, vkapi::MemoryAccessType::WRITE},
104113 {in, vkapi::MemoryAccessType::READ}},
105114 // Shader params buffers
106- {t_out-> logical_limits_ubo (),
107- graph.create_params_buffer (input_size ),
108- graph.create_params_buffer (rev_scales )},
// In order: output logical limits, input logical limits, reciprocal scales.
115+ {graph. logical_limits_ubo (out ),
116+ graph.logical_limits_ubo (in ),
117+ graph.create_params_buffer (recip_scales )},
109118 // Specialization Constants
110119 {},
111120 resize_upsample_nearest2d_node,
112121 {output_sizes, scale_factors}));
113122}
114123
115- void upsample (ComputeGraph& graph, const std::vector<ValueRef>& args) {
116- return add_upsample_nearest2d_node (graph, args[0 ], args[1 ], args[2 ], args[3 ]);
// Entry point registered for aten.upsample_nearest2d.vec. Forwards args as
// [0] in, [1] output_sizes, [2] scale_factors, [3] out in NEAREST mode, and
// passes kDummyValueRef in the align_corners position.
124+ void upsample_nearest2d (
125+ ComputeGraph& graph,
126+ const std::vector<ValueRef>& args) {
127+ return add_upsample_nearest2d_node (
128+ graph,
129+ UpsampleMode::NEAREST,
130+ args[0 ],
131+ args[1 ],
132+ kDummyValueRef ,
133+ args[2 ],
134+ args[3 ]);
135+ }
136+
// Entry point registered for aten.upsample_bilinear2d.vec. Forwards args as
// [0] in, [1] output_sizes, [2] align_corners, [3] scale_factors, [4] out in
// BILINEAR mode.
137+ void upsample_bilinear2d (
138+ ComputeGraph& graph,
139+ const std::vector<ValueRef>& args) {
140+ return add_upsample_nearest2d_node (
141+ graph,
142+ UpsampleMode::BILINEAR,
143+ args[0 ],
144+ args[1 ],
145+ args[2 ],
146+ args[3 ],
147+ args[4 ]);
117148}
118149
// Binds each aten operator name to its entry point defined in this file.
119150REGISTER_OPERATORS {
120- VK_REGISTER_OP (aten.upsample_nearest2d .vec , upsample);
151+ VK_REGISTER_OP (aten.upsample_nearest2d .vec , upsample_nearest2d);
152+ VK_REGISTER_OP (aten.upsample_bilinear2d .vec , upsample_bilinear2d);
121153}
122154
123155} // namespace vkcompute
0 commit comments