Skip to content

Commit cf52fe3

Browse files
LukeBoyercopybara-github
authored andcommitted
Add "model provider" pattern to litert device scripts. This is used to specify scripts that generate data dependencies at runtime vs specifying them directly. Wire this up with ats to support models downloaded from the public bucket.
LiteRT-PiperOrigin-RevId: 825802418
1 parent 3792211 commit cf52fe3

File tree

12 files changed

+279
-47
lines changed

12 files changed

+279
-47
lines changed

litert/ats/BUILD

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ load("@rules_cc//cc:cc_test.bzl", "cc_test")
1717
load("//litert/ats:ats.bzl", "litert_define_ats")
1818
load("//litert/build_common:litert_build_defs.bzl", "litert_test")
1919
load("//litert/integration_test:litert_device.bzl", "litert_device_test")
20+
load("//litert/integration_test:litert_device_script.bzl", "make_download_model_provider")
2021

2122
package(
2223
# copybara:uncomment default_applicable_licenses = ["@org_tensorflow//tensorflow:license"],
@@ -328,6 +329,11 @@ cc_library(
328329

329330
# PRE-CONFIGURED CTS SUITES ########################################################################
330331

332+
make_download_model_provider(
333+
name = "ats_models_provider",
334+
url = "https://storage.googleapis.com/litert/ats_models.tar.gz",
335+
)
336+
331337
litert_define_ats(
332338
name = "sample_cpu_ats",
333339
backend = "cpu",
@@ -340,8 +346,9 @@ litert_define_ats(
340346
backend = "example",
341347
compile_only_suffix = "_aot",
342348
do_register = [
343-
"sub.*f32",
344-
"mul.*f32",
349+
# "sub.*f32",
350+
# "mul.*f32",
351+
".*ExtraModel.*",
345352
],
346353
extra_flags = ["--limit=1"],
347354
jit_suffix = "",

litert/ats/ats.bzl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ def litert_define_ats(
8989
if compile_aot_and_run_suffix:
9090
fail("Compile aot and run on device is not supported yet.")
9191

92-
init_run_args = []
92+
# TODO: Unify local workdir paths for scripting.
93+
init_run_args = ["--extra_models={}".format("/data/local/tmp/runfiles/user/tmp/litert_extras")]
9394
if is_npu_backend(backend):
9495
init_run_args += [
9596
"--dispatch_dir=\"{}\"".format(dispatch_device_rlocation(backend)),
@@ -114,6 +115,7 @@ def litert_define_ats(
114115
local_suffix = "",
115116
exec_args = run_args,
116117
backend_id = backend,
118+
model_providers = ["//litert/integration_test:ats_models_provider"],
117119
)
118120

119121
compile_args = _make_ats_args(
@@ -135,4 +137,5 @@ def litert_define_ats(
135137
exec_args = compile_args,
136138
build_for_host = True,
137139
build_for_device = False,
140+
model_providers = ["//litert/integration_test:ats_models_provider"],
138141
)

litert/ats/ats_aot.sh

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
source "${0%.*}_lib.sh" || exit 1
1818

19-
models_out="/tmp/ats_models"
19+
# TODO: Unify workdirs with other scripts.
20+
readonly models_out="/tmp/litert_extras/ats"
2021
readonly exec_args=("${@:1}")
2122

2223
compile_bin=""
@@ -29,6 +30,7 @@ compiler_libs=()
2930

3031
function setup_context() {
3132
mkdir -p "$models_out"
33+
rm -rf "$models_out"/*
3234

3335
local in_flags=$1
3436
for a in ${in_flags[@]}; do
@@ -67,6 +69,15 @@ function setup_context() {
6769
link_path=$(dirname ${lib}):${link_path}
6870
fi
6971
done
72+
73+
local input_models=($(get_provided_models))
74+
if [[ $? -ne 0 ]]; then
75+
fatal "Failed to get provided models."
76+
fi
77+
78+
if [[ -n "${input_models[*]}" ]]; then
79+
compile_args+=("--extra_models=$(str_join "," ${input_models[@]})")
80+
fi
7081
}
7182

7283
function print_args() {

litert/integration_test/BUILD

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ load("@rules_cc//cc:cc_test.bzl", "cc_test")
1919
load("//litert/build_common:special_rule.bzl", "litert_android_linkopts")
2020
load("//litert/integration_test:litert_device.bzl", "litert_device_exec", "litert_device_test", "litert_integration_test")
2121
load("//litert/integration_test:litert_device_common.bzl", "device_rlocation", "get_libs")
22-
load("//litert/integration_test:litert_device_script.bzl", "litert_device_script")
22+
load("//litert/integration_test:litert_device_script.bzl", "litert_device_script", "make_download_model_provider")
2323
# copybara:uncomment load("//litert/integration_test/google:litert_device_guitar.bzl", "litert_cpu_mh_guitar_test", "litert_mediatek_mh_guitar_test", "litert_pixel_9_mh_guitar_test", "litert_qualcomm_mh_guitar_test")
2424

2525
package(
@@ -427,6 +427,17 @@ cc_binary(
427427
],
428428
)
429429

430+
sh_binary(
431+
name = "dummy_model_provider",
432+
srcs = ["dummy_model_provider.sh"],
433+
)
434+
435+
sh_library(
436+
name = "device_script_common",
437+
srcs = ["device_script_common.sh"],
438+
visibility = ["//litert:litert_public"],
439+
)
440+
430441
litert_device_script(
431442
name = "check_script_device",
432443
testonly = True,
@@ -441,11 +452,40 @@ litert_device_script(
441452
exec_args = [
442453
"--check_device",
443454
],
455+
model_providers = [
456+
":dummy_model_provider",
457+
],
458+
script = "device_script_test.sh",
459+
)
460+
461+
litert_device_script(
462+
name = "check_script_host",
463+
testonly = True,
464+
backend_id = "example",
465+
bin = ":dummy_binary",
466+
build_for_device = False,
467+
build_for_host = True,
468+
data = [
469+
"//litert/test:testdata/mobilenet_v2_1.0_224.tflite",
470+
"//litert/test:testdata/simple_add_op_qc_v75_precompiled.tflite",
471+
],
472+
exec_args = [
473+
"--check_host",
474+
],
475+
model_providers = [
476+
":dummy_model_provider",
477+
],
444478
script = "device_script_test.sh",
445479
)
446480

481+
make_download_model_provider(
482+
name = "ats_models_provider",
483+
url = "https://storage.googleapis.com/litert/ats_models.tar.gz",
484+
)
485+
447486
litert_device_exec(
448487
name = "exec_for_testing_device",
488+
testonly = True,
449489
backend_id = "example",
450490
data = [
451491
"//litert/test:testdata/mobilenet_v2_1.0_224.tflite",
@@ -459,6 +499,9 @@ litert_device_exec(
459499
"--expected_libs_on_ld={}".format(",".join([device_rlocation(lib) for lib in get_libs("example")])),
460500
],
461501
local_suffix = "",
502+
model_providers = [
503+
":ats_models_provider",
504+
],
462505
remote_suffix = None,
463506
target = ":dummy_binary",
464507
)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Google LLC.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
#!/bin/bash
16+
17+
reset_color='\033[0m'
18+
host_color='\033[34m'
19+
hightlight_color='\033[36m'
20+
error_color='\033[31m'
21+
22+
# Print a message in the canonical color.
23+
function print() {
24+
echo -e "${host_color}${1}${reset_color}"
25+
}
26+
27+
# Print a message in the canonical hightlight color.
28+
function print_hightlight() {
29+
echo -e "${hightlight_color}${1}${reset_color}"
30+
}
31+
32+
# Print a file or host file device file pair.
33+
function print_file() {
34+
if [[ "$#" -ne 2 ]]; then
35+
echo -e " ${1}"
36+
else
37+
echo -e " ${1} => ${2}"
38+
fi
39+
}
40+
41+
# Print message and exit.
42+
function fatal() {
43+
echo -e "${error_color}ERROR: ${reset_color}${1}"
44+
exit 1
45+
}
46+
47+
# Join with delim.
48+
function str_join() {
49+
local IFS=$1
50+
shift
51+
echo "$*"
52+
}

litert/integration_test/device_script_template.sh

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#!/bin/bash
2-
31
# Copyright 2025 Google LLC.
42
#
53
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,10 +17,7 @@
1917
# Shell library for working with data and executable files from bzl between host
2018
# and device. Meant to be templated via litert_device_script.bzl.
2119

22-
reset_color='\033[0m'
23-
host_color='\033[34m'
24-
hightlight_color='\033[36m'
25-
error_color='\033[31m'
20+
source third_party/odml/litert/litert/integration_test/device_script_common.sh || exit 1
2621

2722
# Root of runfiles on the device.
2823
device_runfiles_root="/data/local/tmp/runfiles"
@@ -169,28 +164,17 @@ function find_device_runtime_lib() {
169164
done
170165
}
171166

172-
# Print a message in the canonical color.
173-
function print() {
174-
echo -e "${host_color}${1}${reset_color}"
175-
}
176-
177-
# Print a message in the canonical hightlight color.
178-
function print_hightlight() {
179-
echo -e "${hightlight_color}${1}${reset_color}"
180-
}
181-
182-
# Print a file or host file device file pair.
183-
function print_file() {
184-
if [[ "$#" -ne 2 ]]; then
185-
echo -e " ${1}"
167+
# Call any/all model provider scripts built with this tool. The return code
168+
# of this need to be checked by callers.
169+
function get_provided_models() {
170+
local model_providers=@@model_providers@@
171+
if [[ "$model_providers" == "@@"*"@@" ]]; then
172+
echo ""
186173
else
187-
echo -e " ${1} => ${2}"
174+
for provider in ${model_providers[@]}; do
175+
$provider
176+
done
188177
fi
189178
}
190179

191-
# Print message and exit.
192-
function fatal() {
193-
echo -e "${error_color}ERROR: ${reset_color}${1}"
194-
exit 1
195-
}
196180

litert/integration_test/device_script_test.sh

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,19 @@ function check_len() {
7272
if (( ${#array[@]} == ${len} )); then
7373
echo "${tag} array len OK"
7474
else
75-
echo "${tag}: array len NOT OK"
75+
echo "${tag}: array len NOT OK, expected ${len}, got ${#array[@]}"
7676
exit 1
7777
fi
7878
}
7979

80+
provided_models=($(get_provided_models))
81+
if [ $? -ne 0 ]; then
82+
echo "Failed to get provided models."
83+
exit 1
84+
else
85+
check_len "provided_models" 2 ${provided_models[*]}
86+
fi
87+
8088
check_len "data_files" 2 $(data_files)
8189
for file in $(data_files); do
8290
check_file "$file" "data_files" ".tflite"
@@ -105,3 +113,6 @@ if [[ -n "$check_device" ]]; then
105113
check_file "$(find_device_dispatch)" "find_device_dispatch" "Example.so"
106114
check_file "$(find_device_runtime_lib)" "find_device_runtime_lib" "libLiteRtRuntimeCApi.so"
107115
fi
116+
117+
118+
# echo "$(get_provided_models)"
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2025 Google LLC.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
#!/bin/bash
16+
17+
source third_party/odml/litert/litert/integration_test/device_script_common.sh || exit 1
18+
19+
# TODO: Unify workdirs with other scripts.
20+
readonly work_dir="/tmp/litert_extras"
21+
mkdir -p "${work_dir}"
22+
23+
readonly url=@@url@@
24+
25+
if [[ "$url" == "@@"*"@@" ]]; then
26+
fatal "No url templated into the script."
27+
elif [[ -z "${url}" ]]; then
28+
fatal "Url is empty."
29+
fi
30+
31+
readonly target_file="${work_dir}/$(basename ${url})"
32+
33+
if [[ ${target_file} != *".tar.gz" ]]; then
34+
fatal "Target file is not a .tar.gz: ${target_file}"
35+
fi
36+
37+
wget -p -O ${target_file} ${url}
38+
if [[ $? -ne 0 ]]; then
39+
fatal "Failed to download model from ${url}."
40+
fi
41+
42+
models=($(tar -xhvf ${target_file} -C ${work_dir}))
43+
rm -f ${target_file}
44+
for model in ${models[@]}; do
45+
echo "${work_dir}/${model}"
46+
done
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2025 Google LLC.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
#!/bin/bash
16+
17+
echo "${work_dir}/dummy_model1.tflite"
18+
echo "${work_dir}/dummy_model2.tflite"

litert/integration_test/litert_device.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def litert_device_exec(
133133
exec_args = [],
134134
remote_suffix = "",
135135
local_suffix = "_adb",
136+
model_providers = [],
136137
testonly = True):
137138
"""
138139
Macro to execute a binary target on a device through adb.
@@ -159,12 +160,14 @@ def litert_device_exec(
159160
bin = target,
160161
script = "//litert/integration_test:mobile_install.sh",
161162
exec_args = exec_args,
163+
model_providers = model_providers,
162164
testonly = testonly,
163165
backend_id = backend_id,
164166
)
165167

166168
# Copybara comment doesn't work right if it is inside an if statement (breaks formatting).
167169
if remote_suffix != None:
170+
# Note model providers are not compatible with mobile harness.
168171
_litert_mh_exec(
169172
name = name + remote_suffix,
170173
target = target,

0 commit comments

Comments
 (0)