Skip to content

Commit 3b5463c

Browse files
author
morelos
committed
[ET-VK][Ops] common test utils for converting aten types to vulkan types
Pull Request resolved: #11575 # Context Most op test frameworks make use of common function utilities seeking to convert ATen scalartypes to ET vulkan scalartypes. In order to make this more accessible without having to constantly redefine these utility functions, this diff exists to migrate that. This will be re-used for quantization, dequantization, and choose_qparams operators. # Changes We migrate from_at_scalartype from existing test frameworks for operators. Currently this included `linear_weight_int4`, `rotary_embedding`, and `sdpa`. We also include more test utilities that are useful for understanding in plaintext the type error when throwing an exception. We need to modify the targets to ensure that the test utils is visible to all other files. ghstack-source-id: 290376486 @exported-using-ghexport Differential Revision: [D76464550](https://our.internmc.facebook.com/intern/diff/D76464550/)
1 parent 73fc9a9 commit 3b5463c

File tree

6 files changed

+187
-60
lines changed

6 files changed

+187
-60
lines changed

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1515
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1616

17+
#include "test_utils.h"
18+
1719
#include <cassert>
1820

1921
//
@@ -201,26 +203,6 @@ void test_reference_linear_qcs4w(
201203
ASSERT_TRUE(at::allclose(out, out_ref));
202204
}
203205

204-
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
205-
using namespace vkcompute;
206-
switch (at_scalartype) {
207-
case c10::kFloat:
208-
return vkapi::kFloat;
209-
case c10::kHalf:
210-
return vkapi::kHalf;
211-
case c10::kInt:
212-
return vkapi::kInt;
213-
case c10::kLong:
214-
return vkapi::kInt;
215-
case c10::kChar:
216-
return vkapi::kChar;
217-
case c10::kByte:
218-
return vkapi::kByte;
219-
default:
220-
VK_THROW("Unsupported at::ScalarType!");
221-
}
222-
}
223-
224206
void test_vulkan_linear_qga4w_impl(
225207
const int B,
226208
const int M,

backends/vulkan/test/op_tests/rotary_embedding_test.cpp

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1515
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1616

17+
#include "test_utils.h"
18+
1719
#include <cassert>
1820

1921
//
@@ -55,26 +57,6 @@ std::pair<at::Tensor, at::Tensor> rotary_embedding_impl(
5557
// Test functions
5658
//
5759

58-
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
59-
using namespace vkcompute;
60-
switch (at_scalartype) {
61-
case c10::kFloat:
62-
return vkapi::kFloat;
63-
case c10::kHalf:
64-
return vkapi::kHalf;
65-
case c10::kInt:
66-
return vkapi::kInt;
67-
case c10::kLong:
68-
return vkapi::kInt;
69-
case c10::kChar:
70-
return vkapi::kChar;
71-
case c10::kByte:
72-
return vkapi::kByte;
73-
default:
74-
VK_THROW("Unsupported at::ScalarType!");
75-
}
76-
}
77-
7860
void test_reference(
7961
const int n_heads = 4,
8062
const int n_kv_heads = 2,

backends/vulkan/test/op_tests/sdpa_test.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
1919
#include <executorch/extension/llm/custom_ops/op_sdpa.h>
2020

21+
#include "test_utils.h"
22+
2123
#include <cassert>
2224
#include <iostream>
2325

@@ -261,24 +263,6 @@ void test_reference_sdpa(
261263
}
262264
}
263265

264-
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
265-
using namespace vkcompute;
266-
switch (at_scalartype) {
267-
case c10::kFloat:
268-
return vkapi::kFloat;
269-
case c10::kHalf:
270-
return vkapi::kHalf;
271-
case c10::kInt:
272-
return vkapi::kInt;
273-
case c10::kLong:
274-
return vkapi::kInt;
275-
case c10::kChar:
276-
return vkapi::kChar;
277-
default:
278-
VK_THROW("Unsupported at::ScalarType!");
279-
}
280-
}
281-
282266
void test_vulkan_sdpa(
283267
const int start_input_pos,
284268
const int base_sequence_len,

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,28 @@ def define_common_targets(is_fbcode = False):
142142
platforms = get_platforms(),
143143
)
144144

145+
runtime.cxx_library(
146+
name = "test_utils",
147+
srcs = [
148+
"test_utils.cpp",
149+
],
150+
headers = [
151+
"test_utils.h",
152+
],
153+
exported_headers = [
154+
"test_utils.h",
155+
],
156+
deps = [
157+
"//executorch/backends/vulkan:vulkan_graph_runtime",
158+
"//executorch/runtime/core/exec_aten:lib",
159+
runtime.external_dep_location("libtorch"),
160+
],
161+
visibility = [
162+
"//executorch/backends/vulkan/test/op_tests/...",
163+
"@EXECUTORCH_CLIENTS",
164+
],
165+
)
166+
145167
define_test_targets(
146168
"compute_graph_op_tests",
147169
src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]"
@@ -150,9 +172,20 @@ def define_common_targets(is_fbcode = False):
150172
define_test_targets(
151173
"sdpa_test",
152174
extra_deps = [
175+
":test_utils",
153176
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
154177
"//executorch/extension/tensor:tensor",
155178
]
156179
)
157-
define_test_targets("linear_weight_int4_test")
158-
define_test_targets("rotary_embedding_test")
180+
define_test_targets(
181+
"linear_weight_int4_test",
182+
extra_deps = [
183+
":test_utils",
184+
]
185+
)
186+
define_test_targets(
187+
"rotary_embedding_test",
188+
extra_deps = [
189+
":test_utils",
190+
]
191+
)
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "test_utils.h"
10+
11+
#include <stdexcept>
12+
13+
executorch::aten::ScalarType at_scalartype_to_et_scalartype(
14+
at::ScalarType dtype) {
15+
using ScalarType = executorch::aten::ScalarType;
16+
switch (dtype) {
17+
case at::kByte:
18+
return ScalarType::Byte;
19+
case at::kChar:
20+
return ScalarType::Char;
21+
case at::kShort:
22+
return ScalarType::Short;
23+
case at::kInt:
24+
return ScalarType::Int;
25+
case at::kLong:
26+
return ScalarType::Long;
27+
case at::kHalf:
28+
return ScalarType::Half;
29+
case at::kFloat:
30+
return ScalarType::Float;
31+
case at::kDouble:
32+
return ScalarType::Double;
33+
default:
34+
throw std::runtime_error("Unsupported dtype");
35+
}
36+
}
37+
38+
std::string scalar_type_name(c10::ScalarType dtype) {
39+
switch (dtype) {
40+
case c10::kLong:
41+
return "c10::kLong";
42+
case c10::kShort:
43+
return "c10::kShort";
44+
case c10::kComplexHalf:
45+
return "c10::kComplexHalf";
46+
case c10::kComplexFloat:
47+
return "c10::kComplexFloat";
48+
case c10::kComplexDouble:
49+
return "c10::kComplexDouble";
50+
case c10::kBool:
51+
return "c10::kBool";
52+
case c10::kQInt8:
53+
return "c10::kQInt8";
54+
case c10::kQUInt8:
55+
return "c10::kQUInt8";
56+
case c10::kQInt32:
57+
return "c10::kQInt32";
58+
case c10::kBFloat16:
59+
return "c10::kBFloat16";
60+
case c10::kQUInt4x2:
61+
return "c10::kQUInt4x2";
62+
case c10::kQUInt2x4:
63+
return "c10::kQUInt2x4";
64+
case c10::kFloat:
65+
return "c10::kFloat";
66+
case c10::kHalf:
67+
return "c10::kHalf";
68+
case c10::kInt:
69+
return "c10::kInt";
70+
case c10::kChar:
71+
return "c10::kChar";
72+
case c10::kByte:
73+
return "c10::kByte";
74+
case c10::kDouble:
75+
return "c10::kDouble";
76+
case c10::kUInt16:
77+
return "c10::kUInt16";
78+
case c10::kBits16:
79+
return "c10::kBits16";
80+
default:
81+
return "Unknown(" + std::to_string(static_cast<int>(dtype)) + ")";
82+
}
83+
}
84+
85+
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
86+
using namespace vkcompute;
87+
switch (at_scalartype) {
88+
case c10::kHalf:
89+
return vkapi::kHalf;
90+
case c10::kFloat:
91+
return vkapi::kFloat;
92+
case c10::kDouble:
93+
return vkapi::kDouble;
94+
case c10::kInt:
95+
return vkapi::kInt;
96+
case c10::kLong:
97+
return vkapi::kLong;
98+
case c10::kChar:
99+
return vkapi::kChar;
100+
case c10::kByte:
101+
return vkapi::kByte;
102+
case c10::kShort:
103+
return vkapi::kShort;
104+
case c10::kUInt16:
105+
return vkapi::kUInt16;
106+
default:
107+
VK_THROW(
108+
"Unsupported at::ScalarType: ",
109+
scalar_type_name(at_scalartype),
110+
" (",
111+
static_cast<int>(at_scalartype),
112+
")");
113+
}
114+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <string>
12+
13+
#include <ATen/ATen.h>
14+
#include <c10/core/ScalarType.h>
15+
#include <executorch/backends/vulkan/runtime/api/api.h>
16+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
17+
18+
/**
19+
* Convert at::ScalarType to executorch::ScalarType
20+
*/
21+
executorch::aten::ScalarType at_scalartype_to_et_scalartype(
22+
at::ScalarType dtype);
23+
24+
/**
25+
* Get the string name of a c10::ScalarType for better error messages
26+
*/
27+
std::string scalar_type_name(c10::ScalarType dtype);
28+
29+
/**
30+
* Convert c10::ScalarType to vkcompute::vkapi::ScalarType
31+
*/
32+
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype);

0 commit comments

Comments
 (0)