File tree Expand file tree Collapse file tree 4 files changed +36
-0
lines changed
Expand file tree Collapse file tree 4 files changed +36
-0
lines changed Original file line number Diff line number Diff line change @@ -25,6 +25,13 @@ AOTInductorModelContainerGetNumOutputsFunc
2525 AOTInductorModelContainerGetNumOutputs = nullptr ;
2626AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr ;
2727
28+ // Additional global function pointers for AOT Inductor model container
29+ // operations needed by Metal backend
30+ AOTInductorModelContainerGetInputNameFunc
31+ AOTInductorModelContainerGetInputName = nullptr ;
32+ AOTInductorModelContainerGetNumConstantsFunc
33+ AOTInductorModelContainerGetNumConstants = nullptr ;
34+
2835} // extern "C"
2936
3037} // namespace aoti
Original file line number Diff line number Diff line change @@ -70,6 +70,26 @@ extern AOTInductorModelContainerGetNumOutputsFunc
7070 AOTInductorModelContainerGetNumOutputs;
7171extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;
7272
73+ // Retrieves the name of an input tensor by index from the AOTI model container.
74+ // Needed by Metal backend
75+ using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)(
76+ AOTInductorModelContainerHandle container_handle,
77+ size_t input_idx,
78+ const char ** input_name);
79+
80+ // Retrieves the number of constants from the AOTI model container.
81+ // Needed by Metal backend
82+ using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
83+ AOTInductorModelContainerHandle container_handle,
84+ size_t * num_constants);
85+
86+ // Global function pointers (will be loaded dynamically).
87+ // Needed by Metal backend
88+ extern AOTInductorModelContainerGetInputNameFunc
89+ AOTInductorModelContainerGetInputName;
90+ extern AOTInductorModelContainerGetNumConstantsFunc
91+ AOTInductorModelContainerGetNumConstants;
92+
7393} // extern "C"
7494
7595// AOTI Delegate Handle structure
Original file line number Diff line number Diff line change @@ -176,6 +176,12 @@ int32_t aoti_torch_dtype_int64() {
176176 return 4 ; // PyTorch's int64 dtype code
177177}
178178
179+ // Dtype utility function needed by Metal backend.
180+ // Returns the size of the dtype in bytes.
181+ size_t aoti_torch_dtype_element_size (int32_t dtype) {
182+ return dtype_to_element_size (dtype);
183+ }
184+
179185// Cleanup functions
180186void cleanup_tensor_metadata () {
181187 internal::tensor_to_sizes.clear ();
Original file line number Diff line number Diff line change @@ -61,6 +61,9 @@ int32_t aoti_torch_dtype_float32();
6161int32_t aoti_torch_dtype_bfloat16 ();
6262int32_t aoti_torch_dtype_int64 ();
6363
64+ // Dtype utility function needed by Metal backend
65+ size_t aoti_torch_dtype_element_size (int32_t dtype);
66+
6467// Autograd mode functions
6568int32_t aoti_torch_grad_mode_is_enabled ();
6669void aoti_torch_grad_mode_set_enabled (bool enabled);
You can’t perform that action at this time.
0 commit comments