@@ -3,6 +3,44 @@ load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
33load ("@fbsource//xplat/caffe2:pt_ops.bzl" , "pt_operator_library" )
44load ("@fbsource//xplat/executorch/build:runtime_wrapper.bzl" , "runtime" )
55
6+ def define_test_targets (test_name , extra_deps = [], src_file = None , is_fbcode = False ):
7+ deps_list = [
8+ "//third-party/googletest:gtest_main" ,
9+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
10+ runtime .external_dep_location ("libtorch" ),
11+ ] + extra_deps
12+
13+ src_file_str = src_file if src_file else "{}.cpp" .format (test_name )
14+
15+ runtime .cxx_binary (
16+ name = "{}_bin" .format (test_name ),
17+ srcs = [
18+ src_file_str ,
19+ ],
20+ compiler_flags = [
21+ "-Wno-unused-variable" ,
22+ ],
23+ define_static_target = False ,
24+ deps = deps_list ,
25+ )
26+
27+ runtime .cxx_test (
28+ name = test_name ,
29+ srcs = [
30+ src_file_str ,
31+ ],
32+ 33+ fbandroid_additional_loaded_sonames = [
34+ "torch-code-gen" ,
35+ "vulkan_graph_runtime" ,
36+ "vulkan_graph_runtime_shaderlib" ,
37+ ],
38+ platforms = [ANDROID ],
39+ use_instrumentation_test = True ,
40+ deps = deps_list ,
41+ )
42+
43+
644def define_common_targets (is_fbcode = False ):
745 if is_fbcode :
846 return
@@ -82,19 +120,6 @@ def define_common_targets(is_fbcode = False):
82120 default_outs = ["." ],
83121 )
84122
85- runtime .cxx_binary (
86- name = "compute_graph_op_tests_bin" ,
87- srcs = [
88- ":generated_op_correctness_tests_cpp[op_tests.cpp]" ,
89- ],
90- define_static_target = False ,
91- deps = [
92- "//third-party/googletest:gtest_main" ,
93- "//executorch/backends/vulkan:vulkan_graph_runtime" ,
94- runtime .external_dep_location ("libtorch" ),
95- ],
96- )
97-
98123 runtime .cxx_binary (
99124 name = "compute_graph_op_benchmarks_bin" ,
100125 srcs = [
@@ -111,135 +136,17 @@ def define_common_targets(is_fbcode = False):
111136 ],
112137 )
113138
114- runtime .cxx_test (
115- name = "compute_graph_op_tests" ,
116- srcs = [
117- ":generated_op_correctness_tests_cpp[op_tests.cpp]" ,
118- ],
119- 120- fbandroid_additional_loaded_sonames = [
121- "torch-code-gen" ,
122- "vulkan_graph_runtime" ,
123- "vulkan_graph_runtime_shaderlib" ,
124- ],
125- platforms = [ANDROID ],
126- use_instrumentation_test = True ,
127- deps = [
128- "//third-party/googletest:gtest_main" ,
129- "//executorch/backends/vulkan:vulkan_graph_runtime" ,
130- runtime .external_dep_location ("libtorch" ),
131- ],
139+ define_test_targets (
140+ "compute_graph_op_tests" ,
141+ src_file = ":generated_op_correctness_tests_cpp[op_tests.cpp]"
132142 )
133143
134- runtime .cxx_binary (
135- name = "sdpa_test_bin" ,
136- srcs = [
137- "sdpa_test.cpp" ,
138- ],
139- compiler_flags = [
140- "-Wno-unused-variable" ,
141- ],
142- define_static_target = False ,
143- deps = [
144- "//third-party/googletest:gtest_main" ,
145- "//executorch/backends/vulkan:vulkan_graph_runtime" ,
146- "//executorch/extension/llm/custom_ops:custom_ops_aot_lib" ,
147- ],
148- )
149-
150- runtime .cxx_test (
151- name = "sdpa_test" ,
152- srcs = [
153- "sdpa_test.cpp" ,
154- ],
155- 156- fbandroid_additional_loaded_sonames = [
157- "torch-code-gen" ,
158- "vulkan_graph_runtime" ,
159- "vulkan_graph_runtime_shaderlib" ,
160- ],
161- platforms = [ANDROID ],
162- use_instrumentation_test = True ,
163- deps = [
164- "//third-party/googletest:gtest_main" ,
165- "//executorch/backends/vulkan:vulkan_graph_runtime" ,
166- "//executorch/extension/llm/custom_ops:custom_ops_aot_lib" ,
167- "//executorch/extension/tensor:tensor" ,
168- runtime .external_dep_location ("libtorch" ),
169- ],
170- )
171-
172- runtime .cxx_binary (
173- name = "linear_weight_int4_test_bin" ,
174- srcs = [
175- "linear_weight_int4_test.cpp" ,
176- ],
177- compiler_flags = [
178- "-Wno-unused-variable" ,
179- ],
180- define_static_target = False ,
181- deps = [
182- "//third-party/googletest:gtest_main" ,
183- "//executorch/backends/vulkan:vulkan_graph_runtime" ,
184- runtime .external_dep_location ("libtorch" ),
185- ],
186- )
187-
188- runtime .cxx_test (
189- name = "linear_weight_int4_test" ,
190- srcs = [
191- "linear_weight_int4_test.cpp" ,
192- ],
193- 194- fbandroid_additional_loaded_sonames = [
195- "torch-code-gen" ,
196- "vulkan_graph_runtime" ,
197- "vulkan_graph_runtime_shaderlib" ,
198- ],
199- platforms = [ANDROID ],
200- use_instrumentation_test = True ,
201- deps = [
202- "//third-party/googletest:gtest_main" ,
203- "//executorch/backends/vulkan:vulkan_graph_runtime" ,
144+ define_test_targets (
145+ "sdpa_test" ,
146+ extra_deps = [
204147 "//executorch/extension/llm/custom_ops:custom_ops_aot_lib" ,
205148 "//executorch/extension/tensor:tensor" ,
206- runtime .external_dep_location ("libtorch" ),
207- ],
208- )
209-
210- runtime .cxx_binary (
211- name = "rotary_embedding_test_bin" ,
212- srcs = [
213- "rotary_embedding_test.cpp" ,
214- ],
215- compiler_flags = [
216- "-Wno-unused-variable" ,
217- ],
218- define_static_target = False ,
219- deps = [
220- "//third-party/googletest:gtest_main" ,
221- "//executorch/backends/vulkan:vulkan_graph_runtime" ,
222- runtime .external_dep_location ("libtorch" ),
223- ],
224- )
225-
226- runtime .cxx_test (
227- name = "rotary_embedding_test" ,
228- srcs = [
229- "rotary_embedding_test.cpp" ,
230- ],
231- 232- fbandroid_additional_loaded_sonames = [
233- "torch-code-gen" ,
234- "vulkan_graph_runtime" ,
235- "vulkan_graph_runtime_shaderlib" ,
236- ],
237- platforms = [ANDROID ],
238- use_instrumentation_test = True ,
239- deps = [
240- "//third-party/googletest:gtest_main" ,
241- "//executorch/backends/vulkan:vulkan_graph_runtime" ,
242- "//executorch/extension/tensor:tensor" ,
243- runtime .external_dep_location ("libtorch" ),
244- ],
149+ ]
245150 )
151+ define_test_targets ("linear_weight_int4_test" )
152+ define_test_targets ("rotary_embedding_test" )
0 commit comments