Skip to content

Commit 648ee07

Browse files
Update
[ghstack-poisoned]
1 parent dfa435a commit 648ee07

File tree

2 files changed

+658
-0
lines changed

2 files changed

+658
-0
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/apple/metal/runtime/shims/types.h>
12+
13+
namespace executorch {
14+
namespace backends {
15+
namespace metal {
16+
17+
struct AOTIMetalKernelFunctionOpaque;
18+
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
19+
20+
struct AOTIMetalShaderLibraryOpaque;
21+
using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*;
22+
23+
#ifdef __cplusplus
24+
extern "C" {
25+
#endif
26+
27+
// MetalShaderLibrary functions
28+
AOTITorchError aoti_torch_mps_create_shader_library(
29+
const char* metal_shader_source,
30+
AOTIMetalShaderLibraryHandle* library_handle);
31+
32+
AOTITorchError aoti_torch_mps_delete_shader_library(
33+
AOTIMetalShaderLibraryHandle library_handle);
34+
35+
AOTITorchError aoti_torch_mps_get_kernel_function(
36+
AOTIMetalShaderLibraryHandle library_handle,
37+
const char* kernel_name,
38+
AOTIMetalKernelFunctionHandle* function_handle);
39+
40+
// MetalKernelFunction functions
41+
AOTITorchError aoti_torch_mps_start_encoding(
42+
AOTIMetalKernelFunctionHandle func);
43+
44+
AOTITorchError aoti_torch_mps_set_arg_tensor(
45+
AOTIMetalKernelFunctionHandle func,
46+
unsigned idx,
47+
AOTITensorHandle tensor);
48+
49+
AOTITorchError aoti_torch_mps_set_arg_int(
50+
AOTIMetalKernelFunctionHandle func,
51+
unsigned idx,
52+
int64_t val);
53+
54+
// Pure C dispatch functions - single value versions
55+
AOTITorchError aoti_torch_mps_dispatch_single(
56+
AOTIMetalKernelFunctionHandle func,
57+
uint64_t length);
58+
59+
AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
60+
AOTIMetalKernelFunctionHandle func,
61+
uint64_t length,
62+
uint64_t group_size);
63+
64+
// Pure C dispatch functions - array versions
65+
AOTITorchError aoti_torch_mps_dispatch_array(
66+
AOTIMetalKernelFunctionHandle func,
67+
const uint64_t* length,
68+
size_t length_size);
69+
70+
AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
71+
AOTIMetalKernelFunctionHandle func,
72+
const uint64_t* length,
73+
size_t length_size,
74+
const uint64_t* group_size,
75+
size_t group_size_size);
76+
77+
// Memory management functions
78+
AOTITorchError aoti_torch_mps_malloc(void** buffer, size_t num_bytes);
79+
80+
AOTITorchError aoti_torch_mps_free(void* ptr);
81+
82+
AOTITorchError aoti_torch_mps_memcpy(
83+
void* buffer,
84+
size_t constant_offset,
85+
size_t bytes_read,
86+
size_t data_size,
87+
uint8_t* constants_start);
88+
89+
AOTITorchError aoti_torch_mps_copy_buffer(
90+
void* src_buffer,
91+
void* dst_buffer,
92+
size_t data_size,
93+
size_t src_offset,
94+
size_t dst_offset);
95+
96+
// C callback function type for command block execution
97+
typedef void (*aoti_torch_mps_command_block_callback_t)(
98+
AOTIMetalKernelFunctionHandle func,
99+
void* user_data);
100+
101+
// Shared callback function for std::function trampoline
102+
void aoti_torch_mps_shared_callback(
103+
AOTIMetalKernelFunctionHandle func,
104+
void* user_data);
105+
106+
// Pure C version using function pointer and user data for trampoline pattern
107+
AOTITorchError aoti_torch_mps_run_command_block(
108+
AOTIMetalKernelFunctionHandle func,
109+
aoti_torch_mps_command_block_callback_t callback,
110+
void* user_data);
111+
112+
#ifdef __cplusplus
113+
} // extern "C"
114+
#endif
115+
116+
} // namespace metal
117+
} // namespace backends
118+
} // namespace executorch

0 commit comments

Comments
 (0)