21 changes: 17 additions & 4 deletions .ci/scripts/export_model_artifact.sh
@@ -51,9 +51,10 @@ fi
set -eux

DEVICE="$1"
HF_MODEL="$2"
QUANT_NAME="${3:-non-quantized}"
OUTPUT_DIR="${4:-.}"
DTYPE="$2"
HF_MODEL="$3"
QUANT_NAME="${4:-non-quantized}"
OUTPUT_DIR="${5:-.}"

case "$DEVICE" in
cuda)
@@ -67,6 +68,18 @@ case "$DEVICE" in
;;
esac

case "$DTYPE" in
float16)
;;
bfloat16)
;;
*)
echo "Error: Unsupported dtype '$DTYPE'"
echo "Supported dtypes: float16, bfloat16"
exit 1
;;
esac

# Determine model configuration based on HF model ID
case "$HF_MODEL" in
mistralai/Voxtral-Mini-3B-2507)
@@ -155,7 +168,7 @@ optimum-cli export executorch \
--model "$HF_MODEL" \
--task "$TASK" \
--recipe "$DEVICE" \
--dtype bfloat16 \
--dtype "$DTYPE" \
${DEVICE_ARG} \
${MAX_SEQ_LEN_ARG} \
${EXTRA_ARGS} \
2 changes: 1 addition & 1 deletion .github/workflows/cuda.yml
@@ -142,7 +142,7 @@ jobs:
pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}
echo "::endgroup::"
source .ci/scripts/export_model_artifact.sh cuda "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
source .ci/scripts/export_model_artifact.sh cuda bfloat16 "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
test-model-cuda-e2e:
name: test-model-cuda-e2e
12 changes: 9 additions & 3 deletions .github/workflows/metal.yml
@@ -44,6 +44,9 @@ jobs:
name: "whisper-small"
- repo: "openai"
name: "whisper-large-v3-turbo"
dtype:
- "float16"
- "bfloat16"
quant:
- "non-quantized"
with:
@@ -53,7 +56,7 @@
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: 90
secrets-env: EXECUTORCH_HF_TOKEN
upload-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-metal-${{ matrix.quant }}
upload-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-metal-${{ matrix.dtype }}-${{ matrix.quant }}
script: |
set -eux
@@ -76,7 +79,7 @@
${CONDA_RUN} pip list
echo "::endgroup::"
${CONDA_RUN} bash .ci/scripts/export_model_artifact.sh metal "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
${CONDA_RUN} bash .ci/scripts/export_model_artifact.sh metal "${{ matrix.dtype }}" "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
test-model-metal-e2e:
name: test-model-metal-e2e
@@ -92,6 +95,9 @@ jobs:
name: "whisper-small"
- repo: "openai"
name: "whisper-large-v3-turbo"
dtype:
- "float16"
- "bfloat16"
quant:
- "non-quantized"
with:
@@ -100,7 +106,7 @@
submodules: 'recursive'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: 90
download-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-metal-${{ matrix.quant }}
download-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-metal-${{ matrix.dtype }}-${{ matrix.quant }}
script: |
set -eux
4 changes: 4 additions & 0 deletions backends/aoti/common_shims.cpp
@@ -170,6 +170,10 @@ int32_t aoti_torch_dtype_float32() {
return 6; // PyTorch's float32 dtype code
}

int32_t aoti_torch_dtype_float16() {
return 5; // PyTorch's float16 dtype code
}

int32_t aoti_torch_dtype_bfloat16() {
return 15; // PyTorch's bfloat16 dtype code
}
1 change: 1 addition & 0 deletions backends/aoti/common_shims.h
@@ -63,6 +63,7 @@ aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
2 changes: 2 additions & 0 deletions backends/aoti/utils.h
@@ -43,6 +43,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
return executorch::aten::ScalarType::Int;
case 4: // PyTorch's int64 dtype code
return executorch::aten::ScalarType::Long;
case 5: // PyTorch's float16 dtype code
return executorch::aten::ScalarType::Half;
case 6: // PyTorch's float32 dtype code
return executorch::aten::ScalarType::Float;
case 11: // PyTorch's bool dtype code
13 changes: 11 additions & 2 deletions backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -224,12 +224,15 @@ AOTITorchError aoti_torch_mps_mm_out(
MPSDataType mps_dtype;
size_t element_size;

ET_LOG(Debug, "aoti_torch_mps_mm_out: self_tensor scalar_type=%d, SupportedDTypes::FLOAT32=%d, SupportedDTypes::BFLOAT16=%d",
dtype, static_cast<int32_t>(SupportedDTypes::FLOAT32), static_cast<int32_t>(SupportedDTypes::BFLOAT16));
ET_LOG(Debug, "aoti_torch_mps_mm_out: self_tensor scalar_type=%d, SupportedDTypes::FLOAT32=%d, SupportedDTypes::FLOAT16=%d, SupportedDTypes::BFLOAT16=%d",
dtype, static_cast<int32_t>(SupportedDTypes::FLOAT32), static_cast<int32_t>(SupportedDTypes::FLOAT16), static_cast<int32_t>(SupportedDTypes::BFLOAT16));

if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
I'm not very familiar with ET coding practice, but considering you have to modify it in two places, why not have a to_mps_dtype(SupportedDtypes) inline function and call it both here and a few hundred lines down below?
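For illustration, a minimal sketch of such a helper; the name to_mps_dtype and the bool-return error handling are assumptions for this sketch, not anything defined by this PR:

// Hypothetical helper (not part of this PR): centralize the
// SupportedDTypes -> MPSDataType / element-size mapping that
// mm_out, convolution, and scaled_dot_product_attention each repeat.
inline bool to_mps_dtype(int32_t dtype, MPSDataType& mps_dtype, size_t& element_size) {
  if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
    mps_dtype = MPSDataTypeFloat32;
    element_size = sizeof(float);
    return true;
  }
  if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT16)) {
    mps_dtype = MPSDataTypeFloat16;
    element_size = sizeof(uint16_t); // half is 16 bits
    return true;
  }
  if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
    mps_dtype = MPSDataTypeBFloat16;
    element_size = sizeof(uint16_t); // bfloat16 is 16 bits
    return true;
  }
  return false; // unsupported dtype; caller logs and returns the error
}

Each call site would then reduce to something like: if (!to_mps_dtype(dtype, mps_dtype, element_size)) { /* report unsupported dtype */ }.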

mps_dtype = MPSDataTypeFloat32;
element_size = sizeof(float);
} else if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT16)) {
mps_dtype = MPSDataTypeFloat16;
element_size = sizeof(uint16_t); // half is 16 bits
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
mps_dtype = MPSDataTypeBFloat16;
element_size = sizeof(uint16_t); // bfloat16 is 16 bits
@@ -592,6 +595,9 @@ AOTITorchError aoti_torch_mps_convolution(
if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
mps_dtype = MPSDataTypeFloat32;
element_size = sizeof(float);
} else if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT16)) {
mps_dtype = MPSDataTypeFloat16;
element_size = sizeof(uint16_t); // half is 16 bits
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
mps_dtype = MPSDataTypeBFloat16;
element_size = sizeof(uint16_t); // bfloat16 is 16 bits
@@ -1084,6 +1090,9 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
mps_dtype = MPSDataTypeFloat32;
element_size = sizeof(float);
} else if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT16)) {
mps_dtype = MPSDataTypeFloat16;
element_size = sizeof(uint16_t); // half is 16 bits
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
mps_dtype = MPSDataTypeBFloat16;
element_size = sizeof(uint16_t); // bfloat16 is 16 bits
2 changes: 2 additions & 0 deletions backends/apple/metal/runtime/shims/utils.cpp
@@ -21,6 +21,7 @@ bool is_dtype_supported_in_et_metal(int32_t dtype) {
switch (dtype) {
case static_cast<int32_t>(SupportedDTypes::INT64):
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
case static_cast<int32_t>(SupportedDTypes::FLOAT16):
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
return true;
default:
@@ -40,6 +41,7 @@ AOTITorchError validate_dtype(int32_t dtype) {
dtype,
static_cast<int32_t>(SupportedDTypes::INT64),
static_cast<int32_t>(SupportedDTypes::FLOAT32),
static_cast<int32_t>(SupportedDTypes::FLOAT16),
static_cast<int32_t>(SupportedDTypes::BFLOAT16));
return Error::InvalidArgument;
}
2 changes: 1 addition & 1 deletion backends/apple/metal/runtime/shims/utils.h
@@ -24,7 +24,7 @@ enum class SupportedDTypes : int32_t {
// INT16 = 2, // PyTorch's int16 dtype code
// INT32 = 3, // PyTorch's int32 dtype code
INT64 = 4, // PyTorch's int64 dtype code
// FLOAT16 = 5, // PyTorch's float16 dtype code
FLOAT16 = 5, // PyTorch's float16 dtype code
FLOAT32 = 6, // PyTorch's float32 dtype code
// FLOAT64 = 7, // PyTorch's float64 dtype code
// BOOL = 11, // PyTorch's bool dtype code
19 changes: 18 additions & 1 deletion extension/asr/runner/runner.cpp
@@ -192,7 +192,24 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
Info,
"Conversion complete, first value = %f",
static_cast<float>(
preprocessed_features->mutable_data_ptr<float>()[0]));
preprocessed_features
->mutable_data_ptr<::executorch::aten::BFloat16>()[0]));
} else if (expected_dtype == ::executorch::aten::ScalarType::Half) {
ET_LOG(
Info,
"Converting audio features from %s to Float16 (Half). Before converting, first value = %f",
::executorch::runtime::toString(preprocessed_features->scalar_type()),
preprocessed_features->mutable_data_ptr<float>()[0]);
auto convert_result = ::executorch::extension::llm::convert_to_float16(
preprocessed_features);
ET_CHECK_OK_OR_RETURN_ERROR(convert_result.error());
preprocessed_features = convert_result.get();
ET_LOG(
Info,
"Conversion complete, first value = %f",
static_cast<float>(
preprocessed_features
->mutable_data_ptr<::executorch::aten::Half>()[0]));
}
}

24 changes: 24 additions & 0 deletions extension/llm/runner/util.h
@@ -174,6 +174,30 @@ convert_to_bfloat16(const ::executorch::extension::TensorPtr& src_tensor) {
return bf16_tensor;
}

/**
* Helper function to convert a float tensor to float16 (Half).
* Creates a new tensor with Half dtype and copies/converts the data.
*/
inline ::executorch::runtime::Result<::executorch::extension::TensorPtr>
convert_to_float16(const ::executorch::extension::TensorPtr& src_tensor) {
ET_CHECK_OR_RETURN_ERROR(
src_tensor->scalar_type() == ::executorch::aten::ScalarType::Float,
InvalidArgument,
"Float16 conversion only supported from Float source data");

const auto num_elements = static_cast<size_t>(src_tensor->numel());
const float* float_data = src_tensor->const_data_ptr<float>();

auto f16_tensor = ::executorch::extension::empty_like(
src_tensor, ::executorch::aten::ScalarType::Half);
auto* f16_data = f16_tensor->mutable_data_ptr<::executorch::aten::Half>();
for (size_t i = 0; i < num_elements; ++i) {
f16_data[i] = ::executorch::aten::Half(float_data[i]);
}

return f16_tensor;
}

} // namespace llm
} // namespace extension
} // namespace executorch
Comment on lines 201 to 203
Nit: if ET is a C++17 compatible project, why not use nested namespaces?
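For reference, a minimal sketch of what that would look like in util.h, assuming the project does build as C++17 (the helper bodies themselves are unchanged):

// C++17 nested-namespace definition collapses the three separate
// namespace blocks (and their matching closing braces) into one.
namespace executorch::extension::llm {

// ... convert_to_bfloat16, convert_to_float16, and the other helpers ...

} // namespace executorch::extension::llm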
