1616
1717namespace vkcompute {
1818
19+ enum class UpsampleMode : int { NEAREST, BILINEAR };
20+
1921void resize_upsample_nearest2d_node (
2022 ComputeGraph* graph,
2123 const std::vector<ArgGroup>& args,
@@ -39,19 +41,12 @@ void resize_upsample_nearest2d_node(
3941 out->virtual_resize (out_sizes);
4042}
4143
42- // ExecuTorch-Vulkan framework to add node
43- // Args:
44- // in: will be converted from NCHW input tensor to 3D ARGB representation in
45- // openGL (via ExecuTorch) output_sizes: optional 2D array of targetting
46- // output size of H and W dimensions. >= input sizes;
47-
48- // will be computed if only given the scale_factors.
49- // scale_factors: optional 2D array of scale factors for H and W dimensions.
50- // Will be computed if only given the output_sizes.
5144void add_upsample_nearest2d_node (
5245 ComputeGraph& graph,
46+ const UpsampleMode mode,
5347 const ValueRef in,
5448 const ValueRef output_sizes,
49+ const ValueRef align_corners,
5550 const ValueRef scale_factors,
5651 const ValueRef out) {
5752 if (graph.val_is_none (output_sizes) && graph.val_is_none (scale_factors)) {
@@ -63,36 +58,61 @@ void add_upsample_nearest2d_node(
6358 " Invalid input, must provide ONLY one of output_sizes or scale_factors" );
6459 }
6560
66- vTensorPtr t_in = graph.get_tensor (in);
67- utils::uvec3 input_sizes = t_in->logical_limits ();
61+ int align_corners_val = 0 ;
62+ if (is_valid (align_corners) && graph.get_bool (align_corners)) {
63+ align_corners_val = 1 ;
64+ }
65+
66+ utils::uvec3 in_limits = graph.logical_limits_of (in);
67+ utils::uvec3 out_limits = graph.logical_limits_of (out);
68+
69+ uint32_t out_width = out_limits[0u ];
70+ uint32_t out_height = out_limits[1u ];
6871
69- utils::ivec2 input_size = {
70- utils::safe_downcast< int32_t >(input_sizes[ 0 ]),
71- utils::safe_downcast< int32_t >(input_sizes[ 1 ])};
72- utils::vec2 rev_scales = {
73- utils::safe_downcast< float >( 1.0 ), utils::safe_downcast< float >( 1.0 )} ;
72+ float scale_factor_x = float (in_limits[ 0u ]) / float (out_width);
73+ float scale_factor_y = float (in_limits[ 1u ]) / float (out_height);
74+
75+ float recip_scale_factor_x = 1 . 0f / scale_factor_x;
76+ float recip_scale_factor_y = 1 . 0f / scale_factor_y ;
7477
75- // Reverse scale factors that pre-computed before GLSL.
7678 if (!graph.val_is_none (output_sizes)) {
77- auto output_size_ref = graph.get_int_list (output_sizes);
78- rev_scales = {
79- utils::safe_downcast<float >(
80- (float )input_size[0 ] / output_size_ref->at (1 )),
81- utils::safe_downcast<float >(
82- (float )input_size[1 ] / output_size_ref->at (0 ))};
79+ IntListPtr output_size_ref = graph.get_int_list (output_sizes);
80+ out_width = output_size_ref->at (1 );
81+ out_height = output_size_ref->at (0 );
82+
83+ VK_CHECK_COND (out_width == out_limits[0u ]);
84+ VK_CHECK_COND (out_height == out_limits[1u ]);
85+
86+ } else {
87+ DoubleListPtr scales = graph.get_double_list (scale_factors);
88+ scale_factor_x = scales->at (1 );
89+ scale_factor_y = scales->at (0 );
8390
91+ VK_CHECK_COND (in_limits[0u ] * scale_factor_x == out_width);
92+ VK_CHECK_COND (in_limits[1u ] * scale_factor_y == out_height);
93+ }
94+
95+ if (align_corners_val == 1 ) {
96+ recip_scale_factor_x = float (in_limits[0u ] - 1 ) / float (out_width - 1 );
97+ recip_scale_factor_y = float (in_limits[1u ] - 1 ) / float (out_height - 1 );
8498 } else {
85- auto scales = graph.get_double_list (scale_factors);
86- rev_scales = {
87- utils::safe_downcast<float >(1.0 / scales->at (1 )),
88- utils::safe_downcast<float >(1.0 / scales->at (0 ))};
99+ recip_scale_factor_x = float (in_limits[0u ]) / float (out_width);
100+ recip_scale_factor_y = float (in_limits[1u ]) / float (out_height);
89101 }
90102
91- vTensorPtr t_out = graph. get_tensor (out) ;
103+ utils::vec2 recip_scales = {recip_scale_factor_x, recip_scale_factor_y} ;
92104
93- std::string kernel_name ( " upsample_nearest2d " ) ;
105+ std::string kernel_name;
94106 kernel_name.reserve (kShaderNameReserve );
95- add_dtype_suffix (kernel_name, *t_out);
107+ switch (mode) {
108+ case UpsampleMode::NEAREST:
109+ kernel_name = " upsample_nearest2d" ;
110+ break ;
111+ case UpsampleMode::BILINEAR:
112+ kernel_name = " upsample_bilinear2d" ;
113+ break ;
114+ }
115+ add_dtype_suffix (kernel_name, graph.dtype_of (out));
96116
97117 graph.execute_nodes ().emplace_back (new DispatchNode (
98118 graph,
@@ -103,21 +123,44 @@ void add_upsample_nearest2d_node(
103123 {{out, vkapi::MemoryAccessType::WRITE},
104124 {in, vkapi::MemoryAccessType::READ}},
105125 // Shader params buffers
106- {t_out-> logical_limits_ubo (),
107- graph.create_params_buffer (input_size ),
108- graph.create_params_buffer (rev_scales )},
126+ {graph. logical_limits_ubo (out ),
127+ graph.logical_limits_ubo (in ),
128+ graph.create_params_buffer (recip_scales )},
109129 // Specialization Constants
110- {},
130+ {align_corners_val },
111131 resize_upsample_nearest2d_node,
112132 {output_sizes, scale_factors}));
113133}
114134
115- void upsample (ComputeGraph& graph, const std::vector<ValueRef>& args) {
116- return add_upsample_nearest2d_node (graph, args[0 ], args[1 ], args[2 ], args[3 ]);
135+ void upsample_nearest2d (
136+ ComputeGraph& graph,
137+ const std::vector<ValueRef>& args) {
138+ return add_upsample_nearest2d_node (
139+ graph,
140+ UpsampleMode::NEAREST,
141+ args[0 ],
142+ args[1 ],
143+ kDummyValueRef ,
144+ args[2 ],
145+ args[3 ]);
146+ }
147+
148+ void upsample_bilinear2d (
149+ ComputeGraph& graph,
150+ const std::vector<ValueRef>& args) {
151+ return add_upsample_nearest2d_node (
152+ graph,
153+ UpsampleMode::BILINEAR,
154+ args[0 ],
155+ args[1 ],
156+ args[2 ],
157+ args[3 ],
158+ args[4 ]);
117159}
118160
119161REGISTER_OPERATORS {
120- VK_REGISTER_OP (aten.upsample_nearest2d .vec , upsample);
162+ VK_REGISTER_OP (aten.upsample_nearest2d .vec , upsample_nearest2d);
163+ VK_REGISTER_OP (aten.upsample_bilinear2d .vec , upsample_bilinear2d);
121164}
122165
123166} // namespace vkcompute
0 commit comments