Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions backends/apple/metal/metal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,34 +108,62 @@ def preprocess(
options: dict[str, typing.Any] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why we need metal_backend.py? It feels very much like cuda_backend.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree. Next week we should begin refactoring and move most of this code into a new aoti_backend.py. The same applies to most of the code in metal_backend.cpp/cuda_backend.cpp.

# Do not link against the full PyTorch/libtorch library
"aot_inductor.link_libtorch": False,
# Package model constants and other generated files directly in the shared object (.so) file
"aot_inductor.package_constants_in_so": True,
# Separate weight constants from the .so file
"aot_inductor.package": True,
"aot_inductor.package_constants_in_so": False,
# Store weight constants on disk in a binary blob
"aot_inductor.package_constants_on_disk_format": "binary_blob",
# Enable maximum automatic tuning for optimal performance
"max_autotune": True,
# "aot_inductor.debug_compile": True,
# "aot_inductor.force_mmap_weights": False,
}

with collect_unsupported_fallback_kernels():
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
if len(missing_fallback_kernels) > 0:
formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
raise RuntimeError(
f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
"Please add them to the AOTI backend."
)

# Extract the .so and .blob paths from the returned list
so_path = None
blob_path = None
for path in paths:
if path.endswith(".wrapper.so"):
so_path = path
elif path.endswith(".wrapper_weights.blob"):
blob_path = path

if so_path is None or blob_path is None:
raise RuntimeError(
f"Could not find required files in compiled paths, got {paths}"
)

# pyre-ignore[6]: Incompatible parameter type
with open(so_path, "rb") as f:
so_data = f.read()

named_data_store = NamedDataStore()
method_name = MetalBackend.method_name_from_compile_specs(compile_specs)

# Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file.
named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)

# Add weights blob to named data store
with open(blob_path, "rb") as f:
blob_data = f.read()

named_data_store.add_named_data(
method_name + "_so_blob", so_data, 1, "aoti_metal_blob"
method_name + "_weights_blob", blob_data, 1, "aoti_metal_blob"
)

# Clean up the generated so file; it has been packaged into the NamdeDataStore
# Clean up the weights blob file
os.remove(blob_path)

# Clean up the generated so file; it has been packaged into the NamedDataStore
# pyre-ignore[6]: Incompatible parameter type
os.remove(so_path)

Expand Down
26 changes: 26 additions & 0 deletions backends/apple/metal/runtime/metal_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ class ET_EXPERIMENTAL MetalBackend final
Debug,
"MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerRun");

LOAD_SYMBOL(
handle,
update_constants_from_blob,
AOTInductorModelUpdateConstantsFromBlob,
so_handle);
ET_LOG(
Debug,
"MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelUpdateConstantsFromBlob");

ET_LOG(
Debug,
"MetalBackend::load_function_pointers_into_handle - All symbols loaded successfully");
Expand Down Expand Up @@ -203,6 +212,9 @@ class ET_EXPERIMENTAL MetalBackend final
outfile.close();
ET_LOG(Info, "MetalBackend::init - File closed successfully");

// Free the buffer immediately after writing to disk
aoti_metal_buffer->Free();

// Load the ELF using dlopen
void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
ET_CHECK_OR_RETURN_ERROR(
Expand Down Expand Up @@ -234,6 +246,20 @@ class ET_EXPERIMENTAL MetalBackend final

handle->container_handle = container_handle;

// Look into named data map for constant data
std::string weights_blob_key =
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
const void* weights_blob = buffer_res->data();
// Feed the weights blob into the container. Under the hood it's copying
// weights, so we should free the buffer immediately.
ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob(
handle->container_handle, static_cast<const uint8_t*>(weights_blob)));
buffer_res->Free();
}

ET_LOG(Info, "MetalBackend::init - Initialization completed successfully");
return (DelegateHandle*)handle; // Return the handle post-processing
}
Expand Down
Loading