Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions backends/aoti/aoti_model_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ AOTInductorModelContainerGetNumOutputsFunc
AOTInductorModelContainerGetNumOutputs = nullptr;
AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr;

// Global function pointers needed by Metal backend
AOTInductorModelContainerGetInputNameFunc
AOTInductorModelContainerGetInputName = nullptr;
AOTInductorModelContainerGetNumConstantsFunc
AOTInductorModelContainerGetNumConstants = nullptr;

} // extern "C"

} // namespace aoti
Expand Down
16 changes: 16 additions & 0 deletions backends/aoti/aoti_model_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,22 @@ extern AOTInductorModelContainerGetNumOutputsFunc
AOTInductorModelContainerGetNumOutputs;
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;

// Function pointer types needed by Metal backend
using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t input_idx,
const char** input_name);

using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t* num_constants);

// Global function pointers needed by Metal backend
extern AOTInductorModelContainerGetInputNameFunc
AOTInductorModelContainerGetInputName;
extern AOTInductorModelContainerGetNumConstantsFunc
AOTInductorModelContainerGetNumConstants;

} // extern "C"

// AOTI Delegate Handle structure
Expand Down
5 changes: 5 additions & 0 deletions backends/aoti/common_shims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ void cleanup_tensor_metadata() {
internal::tensor_to_strides.clear();
}

// Needed by Metal backend
size_t aoti_torch_dtype_element_size(int32_t dtype) {
return dtype_to_element_size(dtype);
}

} // extern "C"

} // namespace aoti
Expand Down
3 changes: 3 additions & 0 deletions backends/aoti/common_shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ void aoti_torch_grad_mode_set_enabled(bool enabled);
// Cleanup functions for clearing global state
void cleanup_tensor_metadata();

// Needed by Metal backend
size_t aoti_torch_dtype_element_size(int32_t dtype);

} // extern "C"

} // namespace aoti
Expand Down
Loading