| 
 | 1 | +/*  | 
 | 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 3 | + * All rights reserved.  | 
 | 4 | + *  | 
 | 5 | + * This source code is licensed under the BSD-style license found in the  | 
 | 6 | + * LICENSE file in the root directory of this source tree.  | 
 | 7 | + */  | 
 | 8 | + | 
 | 9 | +#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>  | 
 | 10 | + | 
 | 11 | +#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>  | 
 | 12 | +#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>  | 
 | 13 | + | 
 | 14 | +namespace vkcompute {  | 
 | 15 | + | 
 | 16 | +using namespace utils;  | 
 | 17 | + | 
 | 18 | +void resize_tan_node(  | 
 | 19 | +    ComputeGraph* graph,  | 
 | 20 | +    const std::vector<ArgGroup>& args,  | 
 | 21 | +    const std::vector<ValueRef>& extra_args) {  | 
 | 22 | +  (void)extra_args;  | 
 | 23 | +  vTensorPtr out = graph->get_tensor(args[0].refs[0]);  | 
 | 24 | +  vTensorPtr self = graph->get_tensor(args[1].refs[0]);  | 
 | 25 | + | 
 | 26 | +  out->virtual_resize(self->sizes());  | 
 | 27 | +}  | 
 | 28 | + | 
 | 29 | +void add_tan_node(ComputeGraph& graph, const ValueRef in, const ValueRef out) {  | 
 | 30 | +  std::string kernel_name = "tan";  | 
 | 31 | +  add_dtype_suffix(kernel_name, graph.dtype_of(out));  | 
 | 32 | +  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));  | 
 | 33 | + | 
 | 34 | +  vkapi::ParamsBindList ubos({});  | 
 | 35 | +  ubos.append({graph.logical_limits_ubo(out)});  | 
 | 36 | + | 
 | 37 | +  graph.execute_nodes().emplace_back(new DispatchNode(  | 
 | 38 | +      graph,  | 
 | 39 | +      VK_KERNEL_FROM_STR(kernel_name),  | 
 | 40 | +      graph.create_global_wg_size(out),  | 
 | 41 | +      graph.create_local_wg_size(out),  | 
 | 42 | +      // Inputs and Outputs  | 
 | 43 | +      {{out, vkapi::kWrite}, {in, vkapi::kRead}},  | 
 | 44 | +      // Shader params buffers  | 
 | 45 | +      ubos,  | 
 | 46 | +      // Push Constants  | 
 | 47 | +      {},  | 
 | 48 | +      // Specialization Constants  | 
 | 49 | +      {},  | 
 | 50 | +      // Resize Args  | 
 | 51 | +      {},  | 
 | 52 | +      // Resizing Logic  | 
 | 53 | +      resize_tan_node));  | 
 | 54 | +}  | 
 | 55 | + | 
 | 56 | +void tan(ComputeGraph& graph, const std::vector<ValueRef>& args) {  | 
 | 57 | +  return add_tan_node(graph, args[0], args[1]);  | 
 | 58 | +}  | 
 | 59 | + | 
 | 60 | +REGISTER_OPERATORS {  | 
 | 61 | +  VK_REGISTER_OP(aten.tan.default, tan);  | 
 | 62 | +}  | 
 | 63 | + | 
 | 64 | +} // namespace vkcompute  | 
0 commit comments