Skip to content

Commit 3248c72

Browse files
committed
✨ Introduce Apple MPS backend support and testing
1 parent 1c5c7ad commit 3248c72

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

.ci/scripts/test_model.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,15 @@ test_model_with_coreml() {
246246
test_model_with_mps() {
247247
"${PYTHON_EXECUTABLE}" -m examples.apple.mps.scripts.mps_example --model_name="${MODEL_NAME}" --use_fp16
248248
EXPORTED_MODEL=$(find "." -type f -name "${MODEL_NAME}*.pte" -print -quit)
249+
250+
if [ -z "$EXPORTED_MODEL" ]; then
251+
echo "[error] failed to export model: no .pte file found"
252+
exit 1
253+
fi
254+
255+
echo "Testing exported model with mps_executor_runner..."
256+
./examples/apple/mps/scripts/build_mps_executor_runner.sh
257+
./cmake-out/examples/apple/mps/mps_executor_runner --model_path "${EXPORTED_MODEL}"
249258
}
250259

251260
if [[ "${BACKEND}" == "portable" ]]; then

.ci/scripts/wheel/test_macos.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111
if __name__ == "__main__":
1212
test_base.run_tests(
1313
model_tests=[
14+
# test_base.ModelTest(
15+
# model=Model.Mv3,
16+
# backend=Backend.XnnpackQuantizationDelegation,
17+
# ),
18+
# test_base.ModelTest(
19+
# model=Model.Mv3,
20+
# backend=Backend.CoreMlExportAndTest,
21+
# ),
1422
test_base.ModelTest(
1523
model=Model.Mv3,
16-
backend=Backend.XnnpackQuantizationDelegation,
17-
),
18-
test_base.ModelTest(
19-
model=Model.Mv3,
20-
backend=Backend.CoreMlExportAndTest,
24+
backend=Backend.AppleMPS,
2125
),
2226
]
2327
)

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class Backend(str, Enum):
4646
XnnpackQuantizationDelegation = "xnnpack-quantization-delegation"
4747
CoreMlExportOnly = "coreml"
4848
CoreMlExportAndTest = "coreml-test" # AOT export + test with runner
49+
AppleMPS = "mps"
4950

5051
def __str__(self) -> str:
5152
return self.value

0 commit comments

Comments
 (0)