@@ -3,44 +3,6 @@ load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
33load ("@fbsource//xplat/caffe2:pt_ops.bzl" , "pt_operator_library" )
44load ("@fbsource//xplat/executorch/build:runtime_wrapper.bzl" , "runtime" )
55
6- def define_test_targets (test_name , extra_deps = [], src_file = None , is_fbcode = False ):
7- deps_list = [
8- "//third-party/googletest:gtest_main" ,
9- "//executorch/backends/vulkan:vulkan_graph_runtime" ,
10- runtime .external_dep_location ("libtorch" ),
11- ] + extra_deps
12-
13- src_file_str = src_file if src_file else "{}.cpp" .format (test_name )
14-
15- runtime .cxx_binary (
16- name = "{}_bin" .format (test_name ),
17- srcs = [
18- src_file_str ,
19- ],
20- compiler_flags = [
21- "-Wno-unused-variable" ,
22- ],
23- define_static_target = False ,
24- deps = deps_list ,
25- )
26-
27- runtime .cxx_test (
28- name = test_name ,
29- srcs = [
30- src_file_str ,
31- ],
32- 33- fbandroid_additional_loaded_sonames = [
34- "torch-code-gen" ,
35- "vulkan_graph_runtime" ,
36- "vulkan_graph_runtime_shaderlib" ,
37- ],
38- platforms = [ANDROID ],
39- use_instrumentation_test = True ,
40- deps = deps_list ,
41- )
42-
43-
446def define_common_targets (is_fbcode = False ):
457 if is_fbcode :
468 return
@@ -120,6 +82,19 @@ def define_common_targets(is_fbcode = False):
12082 default_outs = ["." ],
12183 )
12284
85+ runtime .cxx_binary (
86+ name = "compute_graph_op_tests_bin" ,
87+ srcs = [
88+ ":generated_op_correctness_tests_cpp[op_tests.cpp]" ,
89+ ],
90+ define_static_target = False ,
91+ deps = [
92+ "//third-party/googletest:gtest_main" ,
93+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
94+ runtime .external_dep_location ("libtorch" ),
95+ ],
96+ )
97+
12398 runtime .cxx_binary (
12499 name = "compute_graph_op_benchmarks_bin" ,
125100 srcs = [
@@ -136,17 +111,136 @@ def define_common_targets(is_fbcode = False):
136111 ],
137112 )
138113
139- define_test_targets (
140- "compute_graph_op_tests" ,
141- src_file = ":generated_op_correctness_tests_cpp[op_tests.cpp]"
114+ runtime .cxx_test (
115+ name = "compute_graph_op_tests" ,
116+ srcs = [
117+ ":generated_op_correctness_tests_cpp[op_tests.cpp]" ,
118+ ],
119+ 120+ fbandroid_additional_loaded_sonames = [
121+ "torch-code-gen" ,
122+ "vulkan_graph_runtime" ,
123+ "vulkan_graph_runtime_shaderlib" ,
124+ ],
125+ platforms = [ANDROID ],
126+ use_instrumentation_test = True ,
127+ deps = [
128+ "//third-party/googletest:gtest_main" ,
129+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
130+ runtime .external_dep_location ("libtorch" ),
131+ ],
142132 )
143133
144- define_test_targets (
145- "sdpa_test" ,
146- extra_deps = [
134+
135+ runtime .cxx_binary (
136+ name = "sdpa_test_bin" ,
137+ srcs = [
138+ "sdpa_test.cpp" ,
139+ ],
140+ compiler_flags = [
141+ "-Wno-unused-variable" ,
142+ ],
143+ define_static_target = False ,
144+ deps = [
145+ "//third-party/googletest:gtest_main" ,
146+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
147+ "//executorch/extension/llm/custom_ops:custom_ops_aot_lib" ,
148+ ],
149+ )
150+
151+ runtime .cxx_test (
152+ name = "sdpa_test" ,
153+ srcs = [
154+ "sdpa_test.cpp" ,
155+ ],
156+ 157+ fbandroid_additional_loaded_sonames = [
158+ "torch-code-gen" ,
159+ "vulkan_graph_runtime" ,
160+ "vulkan_graph_runtime_shaderlib" ,
161+ ],
162+ platforms = [ANDROID ],
163+ use_instrumentation_test = True ,
164+ deps = [
165+ "//third-party/googletest:gtest_main" ,
166+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
167+ "//executorch/extension/llm/custom_ops:custom_ops_aot_lib" ,
168+ "//executorch/extension/tensor:tensor" ,
169+ runtime .external_dep_location ("libtorch" ),
170+ ],
171+ )
172+
173+ runtime .cxx_binary (
174+ name = "linear_weight_int4_test_bin" ,
175+ srcs = [
176+ "linear_weight_int4_test.cpp" ,
177+ ],
178+ compiler_flags = [
179+ "-Wno-unused-variable" ,
180+ ],
181+ define_static_target = False ,
182+ deps = [
183+ "//third-party/googletest:gtest_main" ,
184+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
185+ runtime .external_dep_location ("libtorch" ),
186+ ],
187+ )
188+
189+ runtime .cxx_test (
190+ name = "linear_weight_int4_test" ,
191+ srcs = [
192+ "linear_weight_int4_test.cpp" ,
193+ ],
194+ 195+ fbandroid_additional_loaded_sonames = [
196+ "torch-code-gen" ,
197+ "vulkan_graph_runtime" ,
198+ "vulkan_graph_runtime_shaderlib" ,
199+ ],
200+ platforms = [ANDROID ],
201+ use_instrumentation_test = True ,
202+ deps = [
203+ "//third-party/googletest:gtest_main" ,
204+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
147205 "//executorch/extension/llm/custom_ops:custom_ops_aot_lib" ,
148206 "//executorch/extension/tensor:tensor" ,
149- ]
207+ runtime .external_dep_location ("libtorch" ),
208+ ],
209+ )
210+
211+ runtime .cxx_binary (
212+ name = "rotary_embedding_test_bin" ,
213+ srcs = [
214+ "rotary_embedding_test.cpp" ,
215+ ],
216+ compiler_flags = [
217+ "-Wno-unused-variable" ,
218+ ],
219+ define_static_target = False ,
220+ deps = [
221+ "//third-party/googletest:gtest_main" ,
222+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
223+ runtime .external_dep_location ("libtorch" ),
224+ ],
225+ )
226+
227+ runtime .cxx_test (
228+ name = "rotary_embedding_test" ,
229+ srcs = [
230+ "rotary_embedding_test.cpp" ,
231+ ],
232+ 233+ fbandroid_additional_loaded_sonames = [
234+ "torch-code-gen" ,
235+ "vulkan_graph_runtime" ,
236+ "vulkan_graph_runtime_shaderlib" ,
237+ ],
238+ platforms = [ANDROID ],
239+ use_instrumentation_test = True ,
240+ deps = [
241+ "//third-party/googletest:gtest_main" ,
242+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
243+ "//executorch/extension/tensor:tensor" ,
244+ runtime .external_dep_location ("libtorch" ),
245+ ],
150246 )
151- define_test_targets ("linear_weight_int4_test" )
152- define_test_targets ("rotary_embedding_test" )
0 commit comments