Metal backend: Enable Float16 #15947
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15947
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This PR enables Float16 (Half precision) support in the Metal backend, complementing the existing BFloat16 support. The changes add Float16 as a supported data type across the Metal backend infrastructure and provide conversion utilities for transforming Float32 tensors to Float16.
Key Changes:
- Added a convert_to_float16 utility function mirroring the existing convert_to_bfloat16 pattern (a minimal sketch of this conversion pattern follows the list)
- Enabled the Float16 dtype code (5) in the Metal backend type system and validation
- Extended Metal operations (matrix multiplication, convolution, attention) to handle Float16 data
- Updated CI/CD workflows to test both float16 and bfloat16 dtypes
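As a rough illustration of the conversion pattern described in the first bullet (not the actual util.h implementation, which operates on ExecuTorch tensors), a float32 buffer can be narrowed element-wise to half precision. _Float16 is a Clang/GCC extension used here only to keep the sketch self-contained:

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: narrow a float32 buffer to IEEE-754 half precision.
// The real convert_to_float16 in extension/llm/runner/util.h works on
// ExecuTorch tensors; _Float16 stands in for the runtime's Half type.
std::vector<_Float16> convert_to_float16(const std::vector<float>& src) {
  std::vector<_Float16> dst(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    dst[i] = static_cast<_Float16>(src[i]);  // narrowing float -> half conversion
  }
  return dst;
}
```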
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.
Summary per file:
| File | Description |
|---|---|
| extension/llm/runner/util.h | Added convert_to_float16 helper function for Float→Half tensor conversion |
| extension/asr/runner/runner.cpp | Added Float16 conversion support in ASR audio feature preprocessing |
| backends/apple/metal/runtime/shims/utils.h | Uncommented FLOAT16 enum value to enable the Float16 dtype |
| backends/apple/metal/runtime/shims/utils.cpp | Added Float16 to supported-dtype validation |
| backends/apple/metal/runtime/shims/et_metal_ops.mm | Added Float16 handling in mm, convolution, and attention operations |
| backends/aoti/utils.h | Added dtype code 5 mapping to ScalarType::Half |
| backends/aoti/common_shims.h | Added aoti_torch_dtype_float16 function declaration |
| backends/aoti/common_shims.cpp | Implemented aoti_torch_dtype_float16 returning dtype code 5 |
| .github/workflows/metal.yml | Added dtype matrix parameter (float16/bfloat16) for testing |
| .github/workflows/cuda.yml | Added bfloat16 dtype parameter to script calls for consistency |
| .ci/scripts/export_model_artifact.sh | Added dtype parameter and validation to export script |
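For context on the three AOTI rows above, a minimal sketch of how the pieces fit together, assuming c10-style scalar type codes; only the function name aoti_torch_dtype_float16 and the code 5 come from this PR, the rest is illustrative:

```cpp
#include <cstdint>

// Illustrative c10-style scalar type codes; only Half = 5 is stated in the PR.
enum class ScalarType : int32_t { Half = 5, Float = 6, BFloat16 = 15 };

// Shape of the shim described above: common_shims.h declares it,
// common_shims.cpp returns the numeric code that utils.h maps to Half.
inline int32_t aoti_torch_dtype_float16() {
  return static_cast<int32_t>(ScalarType::Half);  // dtype code 5
}
```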
.github/workflows/cuda.yml (outdated)
      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
      script: |
-       source .ci/scripts/test_model_e2e.sh cuda "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
+       source .ci/scripts/test_model_e2e.sh cuda bfloat16 "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
Copilot AI (Nov 21, 2025):
The script test_model_e2e.sh is being called with a new signature that includes dtype as the second parameter: test_model_e2e.sh cuda bfloat16 "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}". However, based on the current implementation of test_model_e2e.sh (lines 60-62), it expects the old signature: <device> <hf_model> <quant_name> [model_dir]. This mismatch will cause the script to fail as it will interpret bfloat16 as the HF model name and the actual model name as the quant parameter. The test_model_e2e.sh script needs to be updated to accept and handle the dtype parameter in the same way export_model_artifact.sh was updated.
.github/workflows/metal.yml (outdated)
      echo "::endgroup::"
-     ${CONDA_RUN} bash .ci/scripts/test_model_e2e.sh metal "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
+     ${CONDA_RUN} bash .ci/scripts/test_model_e2e.sh metal "${{ matrix.dtype }}" "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
Copilot AI (Nov 21, 2025):
The script test_model_e2e.sh is being called with a new signature that includes dtype as the second parameter: test_model_e2e.sh metal "${{ matrix.dtype }}" "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}". However, the current implementation of test_model_e2e.sh (lines 60-62) expects the old signature: <device> <hf_model> <quant_name> [model_dir]. This mismatch will cause the script to fail as it will interpret dtype as the HF model name and the actual model name as the quant parameter. The test_model_e2e.sh script needs to be updated to accept and handle the dtype parameter consistently with how export_model_artifact.sh was updated.
backends/apple/metal/runtime/shims/et_metal_ops.mm

      ET_LOG(Debug, "aoti_torch_mps_mm_out: self_tensor scalar_type=%d, SupportedDTypes::FLOAT32=%d, SupportedDTypes::FLOAT16=%d, SupportedDTypes::BFLOAT16=%d",
          dtype, static_cast<int32_t>(SupportedDTypes::FLOAT32), static_cast<int32_t>(SupportedDTypes::FLOAT16), static_cast<int32_t>(SupportedDTypes::BFLOAT16));

      if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
I'm not very familiar with ET coding practice, but considering you have to modify this in two places, why not add a to_mps_dtype(SupportedDtypes) inline function and call it both here and a few hundred lines below?
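A minimal sketch of the helper being suggested, assuming the dtype codes follow c10 ScalarType numbering and using a placeholder enum in place of the real MPSDataType constants:

```cpp
#include <cstdint>
#include <stdexcept>

// Placeholder stand-ins; the real code would use MPSDataType and the
// backend's SupportedDTypes enum from shims/utils.h.
enum class SupportedDTypes : int32_t { FLOAT16 = 5, FLOAT32 = 6, BFLOAT16 = 15 };
enum class MpsDType { Float16, Float32, BFloat16 };

inline MpsDType to_mps_dtype(int32_t dtype) {
  switch (static_cast<SupportedDTypes>(dtype)) {
    case SupportedDTypes::FLOAT32:
      return MpsDType::Float32;
    case SupportedDTypes::FLOAT16:
      return MpsDType::Float16;
    case SupportedDTypes::BFLOAT16:
      return MpsDType::BFloat16;
    default:
      throw std::runtime_error("unsupported dtype code");
  }
}
```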
    } // namespace llm
    } // namespace extension
    } // namespace executorch
Nit: if ET is a C++17 compatible project, why not use nested namespaces?
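For reference, the C++17 nested-namespace form would collapse the three closing braces above into one:

```cpp
// C++17 nested namespace definition, equivalent to the three
// separately opened and closed namespaces in the diff above.
namespace executorch::extension::llm {
// ... declarations ...
} // namespace executorch::extension::llm
```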
Enables Float16 in the Metal backend.