@@ -77,7 +77,7 @@ void add_binary_op_texture_node(
7777 kernel_name.reserve (kShaderNameReserve );
7878 kernel_name += op_name;
7979 add_storage_type_suffix (kernel_name, *t_out);
80- add_dtype_suffix (kernel_name, *t_out );
80+ add_dtype_suffix (kernel_name, graph. dtype_of (in1) );
8181
8282 graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
8383 graph,
@@ -121,7 +121,8 @@ void add_binary_op_buffer_node(
121121 kernel_name.reserve (kShaderNameReserve );
122122 kernel_name += op_name;
123123 add_storage_type_suffix (kernel_name, graph.storage_type_of (out));
124- add_dtype_suffix (kernel_name, graph.dtype_of (out));
124+
125+ add_dtype_suffix (kernel_name, graph.dtype_of (in1));
125126
126127 graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
127128 graph,
@@ -189,6 +190,11 @@ DEFINE_BINARY_OP_FN(mul);
189190DEFINE_BINARY_OP_FN (div);
190191DEFINE_BINARY_OP_FN (pow);
191192DEFINE_BINARY_OP_FN (minimum);
193+ DEFINE_BINARY_OP_FN (eq);
194+ DEFINE_BINARY_OP_FN (lt);
195+ DEFINE_BINARY_OP_FN (le);
196+ DEFINE_BINARY_OP_FN (gt);
197+ DEFINE_BINARY_OP_FN (ge);
192198
193199REGISTER_OPERATORS {
194200 VK_REGISTER_OP (aten.add .Tensor , add);
@@ -198,6 +204,11 @@ REGISTER_OPERATORS {
198204 VK_REGISTER_OP (aten.div .Tensor_mode , floor_divide);
199205 VK_REGISTER_OP (aten.pow .Tensor_Tensor , pow);
200206 VK_REGISTER_OP (aten.minimum .default , minimum);
207+ VK_REGISTER_OP (aten.eq .Tensor , eq);
208+ VK_REGISTER_OP (aten.lt .Tensor , lt);
209+ VK_REGISTER_OP (aten.le .Tensor , le);
210+ VK_REGISTER_OP (aten.gt .Tensor , gt);
211+ VK_REGISTER_OP (aten.ge .Tensor , ge);
201212}
202213
203214} // namespace vkcompute
0 commit comments