Skip to content

Commit 09f2182

Browse files
authored
Merge branch 'main' into expand-copy
2 parents ba63691 + fdfeaa4 commit 09f2182

File tree

97 files changed

+1478
-654
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+1478
-654
lines changed

.github/workflows/cuda.yml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,12 @@ jobs:
176176
matrix:
177177
quant:
178178
- name: "non-quantized"
179-
artifact: "voxtral-cuda-export"
179+
artifact: "gemma3-cuda-export"
180180
extra_args: ""
181-
# TODO: enable gemma3 quantization
182-
# - name: "quantized-int4-tile-packed"
183-
# artifact: "voxtral-cuda-quantized-int4-tile-packed"
184-
# extra_args: "--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d"
181+
- name: "quantized-int4-tile-packed"
182+
artifact: "gemma3-cuda-quantized-int4-tile-packed"
183+
extra_args: "--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d"
184+
# TODO: enable int4-weight-only on gemma3.
185185
# - name: "quantized-int4-weight-only"
186186
# artifact: "voxtral-cuda-quantized-int4-weight-only"
187187
# # TODO: adding "--qlinear 4w" produces invalid results. Need further investigation.
@@ -194,7 +194,7 @@ jobs:
194194
gpu-arch-version: 12.6
195195
use-custom-docker-registry: false
196196
submodules: recursive
197-
upload-artifact: gemma3-cuda-export
197+
upload-artifact: ${{ matrix.quant.artifact }}
198198
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
199199
script: |
200200
set -eux
@@ -255,7 +255,7 @@ jobs:
255255
set -eux
256256
257257
echo "::group::Setup ExecuTorch Requirements"
258-
CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_requirements.sh
258+
./install_requirements.sh
259259
pip list
260260
echo "::endgroup::"
261261
@@ -305,7 +305,7 @@ jobs:
305305
set -eux
306306
307307
echo "::group::Setup ExecuTorch Requirements"
308-
CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_requirements.sh
308+
./install_requirements.sh
309309
pip list
310310
echo "::endgroup::"
311311
@@ -363,7 +363,7 @@ jobs:
363363
set -eux
364364
365365
echo "::group::Setup ExecuTorch Requirements"
366-
CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_requirements.sh
366+
./install_requirements.sh
367367
pip list
368368
echo "::endgroup::"
369369
@@ -435,9 +435,9 @@ jobs:
435435
format:
436436
- name: "non-quantized"
437437
artifact: "gemma3-cuda-export"
438-
# TODO: enable quantized gemma3.
439-
# - name: "quantized-int4-tile-packed"
440-
# artifact: "gemma3-cuda-quantized-int4-tile-packed"
438+
- name: "quantized-int4-tile-packed"
439+
artifact: "gemma3-cuda-quantized-int4-tile-packed"
440+
# TODO: enable int4-weight-only on gemma3.
441441
# - name: "quantized-int4-weight-only"
442442
# artifact: "gemma3-cuda-quantized-int4-weight-only"
443443
with:

.mypy.ini

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ ignore_missing_imports = True
8383
[mypy-tosa_tools.*]
8484
ignore_missing_imports = True
8585

86+
[mypy-tosa_serializer]
87+
ignore_missing_imports = True
88+
89+
[mypy-tosa_serializer.*]
90+
ignore_missing_imports = True
91+
8692
[mypy-setuptools.*]
8793
ignore_missing_imports = True
8894

backends/apple/metal/metal_backend.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
# exist fallback operators in et namespace;
3131
supported_fallback_kernels: Dict[str, Any] = {
32-
"aoti_torch_mps_addmm_out": None,
3332
"aoti_torch_mps_convolution": None,
3433
"aoti_torch_mps_mm_out": None,
3534
"at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
@@ -108,34 +107,62 @@ def preprocess(
108107
options: dict[str, typing.Any] = {
109108
# Do not link against the full PyTorch/libtorch library
110109
"aot_inductor.link_libtorch": False,
111-
# Package model constants and other generated files directly in the shared object (.so) file
112-
"aot_inductor.package_constants_in_so": True,
110+
# Separate weight constants from the .so file
111+
"aot_inductor.package": True,
112+
"aot_inductor.package_constants_in_so": False,
113+
# Store weight constants on disk in a binary blob
114+
"aot_inductor.package_constants_on_disk_format": "binary_blob",
113115
# Enable maximum automatic tuning for optimal performance
114116
"max_autotune": True,
115117
# "aot_inductor.debug_compile": True,
116118
# "aot_inductor.force_mmap_weights": False,
117119
}
118120

119121
with collect_unsupported_fallback_kernels():
120-
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
122+
paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
121123
if len(missing_fallback_kernels) > 0:
122124
formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
123125
raise RuntimeError(
124126
f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
125127
"Please add them to the AOTI backend."
126128
)
127129

130+
# Extract the .so and .blob paths from the returned list
131+
so_path = None
132+
blob_path = None
133+
for path in paths:
134+
if path.endswith(".wrapper.so"):
135+
so_path = path
136+
elif path.endswith(".wrapper_weights.blob"):
137+
blob_path = path
138+
139+
if so_path is None or blob_path is None:
140+
raise RuntimeError(
141+
f"Could not find required files in compiled paths, got {paths}"
142+
)
143+
128144
# pyre-ignorep[6]: Incompatible parameter type
129145
with open(so_path, "rb") as f:
130146
so_data = f.read()
131147

132148
named_data_store = NamedDataStore()
133149
method_name = MetalBackend.method_name_from_compile_specs(compile_specs)
150+
151+
# Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file.
152+
named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
153+
154+
# Add weights blob to named data store
155+
with open(blob_path, "rb") as f:
156+
blob_data = f.read()
157+
134158
named_data_store.add_named_data(
135-
method_name + "_so_blob", so_data, 1, "aoti_metal_blob"
159+
method_name + "_weights_blob", blob_data, 1, "aoti_metal_blob"
136160
)
137161

138-
# Clean up the generated so file; it has been packaged into the NamdeDataStore
162+
# Clean up the weights blob file
163+
os.remove(blob_path)
164+
165+
# Clean up the generated so file; it has been packaged into the NamedDataStore
139166
# pyre-ignorep[6]: Incompatible parameter type
140167
os.remove(so_path)
141168

backends/apple/metal/runtime/metal_backend.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,15 @@ class ET_EXPERIMENTAL MetalBackend final
106106
Debug,
107107
"MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerRun");
108108

109+
LOAD_SYMBOL(
110+
handle,
111+
update_constants_from_blob,
112+
AOTInductorModelUpdateConstantsFromBlob,
113+
so_handle);
114+
ET_LOG(
115+
Debug,
116+
"MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelUpdateConstantsFromBlob");
117+
109118
ET_LOG(
110119
Debug,
111120
"MetalBackend::load_function_pointers_into_handle - All symbols loaded successfully");
@@ -203,6 +212,9 @@ class ET_EXPERIMENTAL MetalBackend final
203212
outfile.close();
204213
ET_LOG(Info, "MetalBackend::init - File closed successfully");
205214

215+
// Free the buffer immediately after writing to disk
216+
aoti_metal_buffer->Free();
217+
206218
// Load the ELF using dlopen
207219
void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
208220
ET_CHECK_OR_RETURN_ERROR(
@@ -234,6 +246,20 @@ class ET_EXPERIMENTAL MetalBackend final
234246

235247
handle->container_handle = container_handle;
236248

249+
// Look into named data map for constant data
250+
std::string weights_blob_key =
251+
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
252+
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
253+
if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
254+
ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
255+
const void* weights_blob = buffer_res->data();
256+
// Feed the weights blob into the container. Under the hood it's copying
257+
// weights, so we should free the buffer immediately.
258+
ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob(
259+
handle->container_handle, static_cast<const uint8_t*>(weights_blob)));
260+
buffer_res->Free();
261+
}
262+
237263
ET_LOG(Info, "MetalBackend::init - Initialization completed successfully");
238264
return (DelegateHandle*)handle; // Return the handle post-processing
239265
}

backends/apple/metal/runtime/shims/et_metal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ extern "C" {
354354

355355
// Memory management functions for Metal
356356
void* metal_allocate_buffer(long bytes);
357+
void metal_deallocate_buffer(void* ptr);
357358
bool metal_is_device_pointer(void* ptr);
358359
int metal_copy_memory(
359360
void* dst,

backends/apple/metal/runtime/shims/et_metal.mm

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,21 @@ void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
8686
}
8787
}
8888

89+
void metal_deallocate_buffer(void* ptr) {
90+
@autoreleasepool {
91+
auto it = ptr_to_mtl_buffer.find(ptr);
92+
if (it != ptr_to_mtl_buffer.end()) {
93+
id<MTLBuffer> buffer = it->second;
94+
[buffer release];
95+
ptr_to_mtl_buffer.erase(it);
96+
ET_LOG(Debug, "Deallocated Metal buffer for pointer %p", ptr);
97+
ptr = nullptr;
98+
} else {
99+
ET_LOG(Error, "Failed to find Metal buffer for pointer %p", ptr);
100+
}
101+
}
102+
}
103+
89104
void metal_cleanup_resources() {
90105
if (!ptr_to_mtl_buffer.empty()) {
91106
@autoreleasepool {
@@ -665,12 +680,16 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
665680

666681
// Commit methods
667682
void ETMetalStream::commit() {
668-
if (enableCommitAndContinue_ && commandBuffer_) {
669-
// Use commit-and-continue for better performance
670-
commitAndContinue();
671-
} else {
672-
flush();
683+
if (!commandBuffer_) {
684+
ET_LOG(Error, "ETMetalStream::commit: No command buffer to commit");
685+
return;
673686
}
687+
688+
[commandBuffer_ commit];
689+
ET_LOG(Debug, "ETMetalStream::commit: Committed buffer %p", commandBuffer_);
690+
691+
[commandBuffer_ release];
692+
commandBuffer_ = nil;
674693
}
675694

676695
void ETMetalStream::commitAndWait() {

0 commit comments

Comments
 (0)