Skip to content

Commit f84c423

Browse files
[Metal] Update aoti_common with additional AOTI functions needed by Metal backend (#15003)
1 parent 1a8acf6 commit f84c423

File tree

4 files changed

+36
-0
lines changed

4 files changed

+36
-0
lines changed

backends/aoti/aoti_model_container.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ AOTInductorModelContainerGetNumOutputsFunc
2525
AOTInductorModelContainerGetNumOutputs = nullptr;
2626
AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr;
2727

28+
// Additional global function pointers for AOT Inductor model container
29+
// operations needed by Metal backend
30+
AOTInductorModelContainerGetInputNameFunc
31+
AOTInductorModelContainerGetInputName = nullptr;
32+
AOTInductorModelContainerGetNumConstantsFunc
33+
AOTInductorModelContainerGetNumConstants = nullptr;
34+
2835
} // extern "C"
2936

3037
} // namespace aoti

backends/aoti/aoti_model_container.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,26 @@ extern AOTInductorModelContainerGetNumOutputsFunc
7070
AOTInductorModelContainerGetNumOutputs;
7171
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;
7272

73+
// Retrieves the name of an input tensor by index from the AOTI model container.
74+
// Needed by Metal backend
75+
using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)(
76+
AOTInductorModelContainerHandle container_handle,
77+
size_t input_idx,
78+
const char** input_name);
79+
80+
// Retrieves the number of constants from the AOTI model container.
81+
// Needed by Metal backend
82+
using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
83+
AOTInductorModelContainerHandle container_handle,
84+
size_t* num_constants);
85+
86+
// Global function pointers (will be loaded dynamically).
87+
// Needed by Metal backend
88+
extern AOTInductorModelContainerGetInputNameFunc
89+
AOTInductorModelContainerGetInputName;
90+
extern AOTInductorModelContainerGetNumConstantsFunc
91+
AOTInductorModelContainerGetNumConstants;
92+
7393
} // extern "C"
7494

7595
// AOTI Delegate Handle structure

backends/aoti/common_shims.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ int32_t aoti_torch_dtype_int64() {
176176
return 4; // PyTorch's int64 dtype code
177177
}
178178

179+
// Dtype utility function needed by Metal backend.
180+
// Returns the size of the dtype in bytes.
181+
size_t aoti_torch_dtype_element_size(int32_t dtype) {
182+
return dtype_to_element_size(dtype);
183+
}
184+
179185
// Cleanup functions
180186
void cleanup_tensor_metadata() {
181187
internal::tensor_to_sizes.clear();

backends/aoti/common_shims.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ int32_t aoti_torch_dtype_float32();
6161
int32_t aoti_torch_dtype_bfloat16();
6262
int32_t aoti_torch_dtype_int64();
6363

64+
// Dtype utility function needed by Metal backend
65+
size_t aoti_torch_dtype_element_size(int32_t dtype);
66+
6467
// Autograd mode functions
6568
int32_t aoti_torch_grad_mode_is_enabled();
6669
void aoti_torch_grad_mode_set_enabled(bool enabled);

0 commit comments

Comments
 (0)