Skip to content

Commit 57b311a

Browse files
author
morelos
committed
[ET-VK][Ops] dequantize ops skeleton test framework
Skeleton test framework needed to build out the dequantize_per_tensor and dequantize_per_token operators, based on the CPU implementation. Differential Revision: [D76267021](https://our.internmc.facebook.com/intern/diff/D76267021/) [ghstack-poisoned]
1 parent f2c2380 commit 57b311a

File tree

2 files changed

+304
-0
lines changed

2 files changed

+304
-0
lines changed
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
11+
#include <ATen/ATen.h>
12+
13+
#include <executorch/backends/vulkan/runtime/api/api.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16+
17+
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
18+
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
19+
20+
#include <cassert>
21+
#include <iostream>
22+
23+
namespace torch {
24+
namespace executor {
25+
namespace native {
26+
27+
// Forward declarations of the functions we're testing
28+
Tensor& dequantize_per_tensor_out(
29+
const Tensor& input,
30+
double scale,
31+
int64_t zero_point,
32+
int64_t quant_min,
33+
int64_t quant_max,
34+
ScalarType dtype,
35+
executorch::aten::optional<ScalarType> out_dtype,
36+
Tensor& out);
37+
38+
Tensor& dequantize_per_token_out(
39+
const Tensor& input,
40+
const Tensor& scale,
41+
const Tensor& zero_points,
42+
int64_t quant_min,
43+
int64_t quant_max,
44+
ScalarType dtype,
45+
ScalarType out_dtype,
46+
Tensor& out);
47+
48+
// Wrapper function for dequantize_per_tensor_out without context
49+
Tensor& dequantize_per_tensor_out_no_context(
50+
const Tensor& input,
51+
double scale,
52+
int64_t zero_point,
53+
int64_t quant_min,
54+
int64_t quant_max,
55+
ScalarType dtype,
56+
executorch::aten::optional<ScalarType> out_dtype,
57+
Tensor& out) {
58+
return torch::executor::native::dequantize_per_tensor_out(
59+
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
60+
}
61+
62+
// Wrapper function for dequantize_per_token_out without context
63+
Tensor& dequantize_per_token_out_no_context(
64+
const Tensor& input,
65+
const Tensor& scale,
66+
const Tensor& zero_points,
67+
int64_t quant_min,
68+
int64_t quant_max,
69+
ScalarType dtype,
70+
ScalarType out_dtype,
71+
Tensor& out) {
72+
return torch::executor::native::dequantize_per_token_out(
73+
input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out);
74+
}
75+
76+
// ATen wrapper for dequantize_per_tensor
77+
at::Tensor dequantize_per_tensor_aten(
78+
const at::Tensor& input,
79+
double scale,
80+
int64_t zero_point,
81+
int64_t quant_min,
82+
int64_t quant_max,
83+
at::ScalarType dtype,
84+
at::ScalarType out_dtype) {
85+
auto out = at::empty_like(input, out_dtype);
86+
// Convert at::ScalarType to executorch::ScalarType
87+
ScalarType et_dtype;
88+
ScalarType et_out_dtype;
89+
90+
switch (dtype) {
91+
case at::kByte:
92+
et_dtype = ScalarType::Byte;
93+
break;
94+
case at::kChar:
95+
et_dtype = ScalarType::Char;
96+
break;
97+
case at::kShort:
98+
et_dtype = ScalarType::Short;
99+
break;
100+
case at::kInt:
101+
et_dtype = ScalarType::Int;
102+
break;
103+
case at::kLong:
104+
et_dtype = ScalarType::Long;
105+
break;
106+
default:
107+
throw std::runtime_error("Unsupported dtype");
108+
}
109+
110+
switch (out_dtype) {
111+
case at::kFloat:
112+
et_out_dtype = ScalarType::Float;
113+
break;
114+
case at::kDouble:
115+
et_out_dtype = ScalarType::Double;
116+
break;
117+
default:
118+
throw std::runtime_error("Unsupported out_dtype");
119+
}
120+
121+
executorch::aten::optional<ScalarType> opt_et_out_dtype(et_out_dtype);
122+
123+
WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7)
124+
(input, scale, zero_point, quant_min, quant_max, et_dtype, opt_et_out_dtype, out);
125+
return out;
126+
}
127+
128+
// ATen wrapper for dequantize_per_token
129+
at::Tensor dequantize_per_token_aten(
130+
const at::Tensor& input,
131+
const at::Tensor& scale,
132+
const at::Tensor& zero_points,
133+
int64_t quant_min,
134+
int64_t quant_max,
135+
at::ScalarType dtype,
136+
at::ScalarType out_dtype) {
137+
auto out = at::empty_like(input, out_dtype);
138+
// Convert at::ScalarType to executorch::ScalarType
139+
ScalarType et_dtype;
140+
ScalarType et_out_dtype;
141+
142+
switch (dtype) {
143+
case at::kByte:
144+
et_dtype = ScalarType::Byte;
145+
break;
146+
case at::kChar:
147+
et_dtype = ScalarType::Char;
148+
break;
149+
case at::kShort:
150+
et_dtype = ScalarType::Short;
151+
break;
152+
case at::kInt:
153+
et_dtype = ScalarType::Int;
154+
break;
155+
case at::kLong:
156+
et_dtype = ScalarType::Long;
157+
break;
158+
default:
159+
throw std::runtime_error("Unsupported dtype");
160+
}
161+
162+
switch (out_dtype) {
163+
case at::kFloat:
164+
et_out_dtype = ScalarType::Float;
165+
break;
166+
case at::kDouble:
167+
et_out_dtype = ScalarType::Double;
168+
break;
169+
default:
170+
throw std::runtime_error("Unsupported out_dtype");
171+
}
172+
173+
WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7)
174+
(input, scale, zero_points, quant_min, quant_max, et_dtype, et_out_dtype, out);
175+
return out;
176+
}
177+
178+
} // namespace native
179+
} // namespace executor
180+
} // namespace torch
181+
182+
183+
//
184+
// Test functions
185+
//
186+
187+
// Helper function to get the name of a ScalarType for better error messages
188+
std::string scalar_type_name(c10::ScalarType dtype) {
189+
switch (dtype) {
190+
case c10::kLong:
191+
return "c10::kLong";
192+
case c10::kShort:
193+
return "c10::kShort";
194+
case c10::kComplexHalf:
195+
return "c10::kComplexHalf";
196+
case c10::kComplexFloat:
197+
return "c10::kComplexFloat";
198+
case c10::kComplexDouble:
199+
return "c10::kComplexDouble";
200+
case c10::kBool:
201+
return "c10::kBool";
202+
case c10::kQInt8:
203+
return "c10::kQInt8";
204+
case c10::kQUInt8:
205+
return "c10::kQUInt8";
206+
case c10::kQInt32:
207+
return "c10::kQInt32";
208+
case c10::kBFloat16:
209+
return "c10::kBFloat16";
210+
case c10::kQUInt4x2:
211+
return "c10::kQUInt4x2";
212+
case c10::kQUInt2x4:
213+
return "c10::kQUInt2x4";
214+
default:
215+
return "Unknown(" + std::to_string(static_cast<int>(dtype)) + ")";
216+
}
217+
}
218+
219+
// Maps an ATen/c10 dtype onto the Vulkan API dtype used by the compute graph.
//
// c10::kLong is deliberately narrowed to vkapi::kInt. Any dtype without a
// mapping is reported through VK_THROW, with both its printable name and its
// integer value in the message.
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
  using namespace vkcompute;

  switch (at_scalartype) {
    case c10::kByte:
      return vkapi::kByte;
    case c10::kChar:
      return vkapi::kChar;
    case c10::kShort:
      return vkapi::kShort;
    case c10::kUInt16:
      return vkapi::kUInt16;
    case c10::kInt:
      return vkapi::kInt;
    case c10::kLong:
      // We don't have inherent vkapi::kLong, use kInt instead
      return vkapi::kInt;
    case c10::kHalf:
      return vkapi::kHalf;
    case c10::kFloat:
      return vkapi::kFloat;
    case c10::kDouble:
      return vkapi::kDouble;
    default:
      break;
  }

  // Only reached for dtypes that have no entry in the table above.
  VK_THROW(
      "Unsupported at::ScalarType: ",
      scalar_type_name(at_scalartype),
      " (",
      static_cast<int>(at_scalartype),
      ")");
}
250+
251+
void check_dequantize_args(
252+
int64_t quant_min,
253+
int64_t quant_max,
254+
c10::ScalarType in_dtype,
255+
c10::ScalarType out_dtype) {
256+
using namespace vkcompute;
257+
258+
// Check that quant_min <= quant_max
259+
VK_CHECK_COND(
260+
quant_min <= quant_max,
261+
"quant_min must be <= quant_max, got quant_min: ",
262+
quant_min,
263+
" quant_max: ",
264+
quant_max);
265+
266+
// Check that input dtype is a quantized type
267+
switch (in_dtype) {
268+
case c10::kByte:
269+
case c10::kChar:
270+
case c10::kShort:
271+
case c10::kInt:
272+
case c10::kLong:
273+
break;
274+
default:
275+
VK_THROW(
276+
"Unsupported input dtype: ",
277+
scalar_type_name(in_dtype),
278+
" (",
279+
static_cast<int>(in_dtype),
280+
")");
281+
}
282+
283+
// Check that output dtype is a floating point type
284+
switch (out_dtype) {
285+
case c10::kFloat:
286+
case c10::kDouble:
287+
break;
288+
default:
289+
VK_THROW(
290+
"Unsupported output dtype: ",
291+
scalar_type_name(out_dtype),
292+
" (",
293+
static_cast<int>(out_dtype),
294+
")");
295+
}
296+
}

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,5 +156,13 @@ def define_common_targets(is_fbcode = False):
156156
"//executorch/extension/aten_util:aten_bridge",
157157
]
158158
)
159+
define_test_targets(
160+
"dequantize_test",
161+
extra_deps = [
162+
"//executorch/kernels/quantized/cpu:op_dequantize",
163+
"//executorch/extension/tensor:tensor",
164+
"//executorch/extension/aten_util:aten_bridge",
165+
]
166+
)
159167
define_test_targets("linear_weight_int4_test")
160168
define_test_targets("rotary_embedding_test")

0 commit comments

Comments
 (0)