| 
 | 1 | +/*  | 
 | 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 3 | + * All rights reserved.  | 
 | 4 | + *  | 
 | 5 | + * This source code is licensed under the BSD-style license found in the  | 
 | 6 | + * LICENSE file in the root directory of this source tree.  | 
 | 7 | + */  | 
 | 8 | + | 
 | 9 | + #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>  | 
 | 10 | + | 
 | 11 | + #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>  | 
 | 12 | + | 
 | 13 | + #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>  | 
 | 14 | + #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>  | 
 | 15 | + | 
 | 16 | + namespace vkcompute {  | 
 | 17 | + | 
 | 18 | + using namespace utils;  | 
 | 19 | + | 
 | 20 | + void resize_tan_node(  | 
 | 21 | +  ComputeGraph* graph,  | 
 | 22 | +  const std::vector<ArgGroup>& args,  | 
 | 23 | +  const std::vector<ValueRef>& extra_args) {  | 
 | 24 | +    (void)extra_args;  | 
 | 25 | +    vTensorPtr out = graph->get_tensor(args[0].refs[0]);  | 
 | 26 | +    vTensorPtr self = graph->get_tensor(args[1].refs[0]);  | 
 | 27 | + | 
 | 28 | +    out->virtual_resize(self->sizes());  | 
 | 29 | +}  | 
 | 30 | + | 
 | 31 | + void add_tan_node(  | 
 | 32 | +  ComputeGraph& graph,  | 
 | 33 | +  const ValueRef in,  | 
 | 34 | +  const ValueRef out) {  | 
 | 35 | +    std::string kernel_name = "tan";  | 
 | 36 | +    add_dtype_suffix(kernel_name, graph.dtype_of(out));  | 
 | 37 | +    add_storage_type_suffix(kernel_name, graph.storage_type_of(out));  | 
 | 38 | + | 
 | 39 | +    vkapi::ParamsBindList ubos({});  | 
 | 40 | +    ubos.append({graph.logical_limits_ubo(out)});  | 
 | 41 | + | 
 | 42 | +    graph.execute_nodes().emplace_back(new DispatchNode(  | 
 | 43 | +        graph,  | 
 | 44 | +        VK_KERNEL_FROM_STR(kernel_name),  | 
 | 45 | +        graph.create_global_wg_size(out),  | 
 | 46 | +        graph.create_local_wg_size(out),  | 
 | 47 | +        // Inputs and Outputs  | 
 | 48 | +        {{out, vkapi::kWrite},  | 
 | 49 | +        {in, vkapi::kRead}},  | 
 | 50 | +        // Shader params buffers  | 
 | 51 | +        ubos,  | 
 | 52 | +        // Push Constants  | 
 | 53 | +        {},  | 
 | 54 | +        // Specialization Constants  | 
 | 55 | +        {},  | 
 | 56 | +        // Resize Args  | 
 | 57 | +        {},  | 
 | 58 | +        // Resizing Logic  | 
 | 59 | +        resize_tan_node));  | 
 | 60 | +}  | 
 | 61 | + | 
 | 62 | + void tan(ComputeGraph& graph, const std::vector<ValueRef>& args) {  | 
 | 63 | +  return add_tan_node(graph, args[0], args[1]);  | 
 | 64 | + }  | 
 | 65 | + | 
 | 66 | + | 
 | 67 | + REGISTER_OPERATORS {  | 
 | 68 | +   VK_REGISTER_OP(aten.tan.default, tan);  | 
 | 69 | + }  | 
 | 70 | + | 
 | 71 | + } // namespace vkcompute  | 
0 commit comments