Skip to content
Merged
22 changes: 2 additions & 20 deletions backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include "test_utils.h"

#include <cassert>

//
Expand Down Expand Up @@ -201,26 +203,6 @@ void test_reference_linear_qcs4w(
ASSERT_TRUE(at::allclose(out, out_ref));
}

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kFloat:
return vkapi::kFloat;
case c10::kHalf:
return vkapi::kHalf;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kInt;
case c10::kChar:
return vkapi::kChar;
case c10::kByte:
return vkapi::kByte;
default:
VK_THROW("Unsupported at::ScalarType!");
}
}

void test_vulkan_linear_qga4w_impl(
const int B,
const int M,
Expand Down
22 changes: 2 additions & 20 deletions backends/vulkan/test/op_tests/rotary_embedding_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include "test_utils.h"

#include <cassert>

//
Expand Down Expand Up @@ -55,26 +57,6 @@ std::pair<at::Tensor, at::Tensor> rotary_embedding_impl(
// Test functions
//

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kFloat:
return vkapi::kFloat;
case c10::kHalf:
return vkapi::kHalf;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kInt;
case c10::kChar:
return vkapi::kChar;
case c10::kByte:
return vkapi::kByte;
default:
VK_THROW("Unsupported at::ScalarType!");
}
}

void test_reference(
const int n_heads = 4,
const int n_kv_heads = 2,
Expand Down
20 changes: 2 additions & 18 deletions backends/vulkan/test/op_tests/sdpa_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_sdpa.h>

#include "test_utils.h"

#include <cassert>
#include <iostream>

Expand Down Expand Up @@ -261,24 +263,6 @@ void test_reference_sdpa(
}
}

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kFloat:
return vkapi::kFloat;
case c10::kHalf:
return vkapi::kHalf;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kInt;
case c10::kChar:
return vkapi::kChar;
default:
VK_THROW("Unsupported at::ScalarType!");
}
}

void test_vulkan_sdpa(
const int start_input_pos,
const int base_sequence_len,
Expand Down
46 changes: 44 additions & 2 deletions backends/vulkan/test/op_tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,28 @@ def define_common_targets(is_fbcode = False):
],
)

runtime.cxx_library(
name = "test_utils",
srcs = [
"test_utils.cpp",
],
headers = [
"test_utils.h",
],
exported_headers = [
"test_utils.h",
],
deps = [
"//executorch/backends/vulkan:vulkan_graph_runtime",
"//executorch/runtime/core/exec_aten:lib",
runtime.external_dep_location("libtorch"),
],
visibility = [
"//executorch/backends/vulkan/test/op_tests/...",
"@EXECUTORCH_CLIENTS",
],
)

define_test_targets(
"compute_graph_op_tests",
src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]"
Expand All @@ -144,9 +166,29 @@ def define_common_targets(is_fbcode = False):
define_test_targets(
"sdpa_test",
extra_deps = [
":test_utils",
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/extension/tensor:tensor",
]
)
define_test_targets("linear_weight_int4_test")
define_test_targets("rotary_embedding_test")
define_test_targets(
"quantize_test",
extra_deps = [
":test_utils",
"//executorch/kernels/quantized/cpu:op_quantize",
"//executorch/extension/tensor:tensor",
"//executorch/extension/aten_util:aten_bridge",
]
)
define_test_targets(
"linear_weight_int4_test",
extra_deps = [
":test_utils",
]
)
define_test_targets(
"rotary_embedding_test",
extra_deps = [
":test_utils",
]
)
114 changes: 114 additions & 0 deletions backends/vulkan/test/op_tests/test_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "test_utils.h"

#include <stdexcept>

executorch::aten::ScalarType at_scalartype_to_et_scalartype(
at::ScalarType dtype) {
using ScalarType = executorch::aten::ScalarType;
switch (dtype) {
case at::kByte:
return ScalarType::Byte;
case at::kChar:
return ScalarType::Char;
case at::kShort:
return ScalarType::Short;
case at::kInt:
return ScalarType::Int;
case at::kLong:
return ScalarType::Long;
case at::kHalf:
return ScalarType::Half;
case at::kFloat:
return ScalarType::Float;
case at::kDouble:
return ScalarType::Double;
default:
throw std::runtime_error("Unsupported dtype");
}
}

std::string scalar_type_name(c10::ScalarType dtype) {
switch (dtype) {
case c10::kLong:
return "c10::kLong";
case c10::kShort:
return "c10::kShort";
case c10::kComplexHalf:
return "c10::kComplexHalf";
case c10::kComplexFloat:
return "c10::kComplexFloat";
case c10::kComplexDouble:
return "c10::kComplexDouble";
case c10::kBool:
return "c10::kBool";
case c10::kQInt8:
return "c10::kQInt8";
case c10::kQUInt8:
return "c10::kQUInt8";
case c10::kQInt32:
return "c10::kQInt32";
case c10::kBFloat16:
return "c10::kBFloat16";
case c10::kQUInt4x2:
return "c10::kQUInt4x2";
case c10::kQUInt2x4:
return "c10::kQUInt2x4";
case c10::kFloat:
return "c10::kFloat";
case c10::kHalf:
return "c10::kHalf";
case c10::kInt:
return "c10::kInt";
case c10::kChar:
return "c10::kChar";
case c10::kByte:
return "c10::kByte";
case c10::kDouble:
return "c10::kDouble";
case c10::kUInt16:
return "c10::kUInt16";
case c10::kBits16:
return "c10::kBits16";
default:
return "Unknown(" + std::to_string(static_cast<int>(dtype)) + ")";
}
}

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kHalf:
return vkapi::kHalf;
case c10::kFloat:
return vkapi::kFloat;
case c10::kDouble:
return vkapi::kDouble;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kLong;
case c10::kChar:
return vkapi::kChar;
case c10::kByte:
return vkapi::kByte;
case c10::kShort:
return vkapi::kShort;
case c10::kUInt16:
return vkapi::kUInt16;
default:
VK_THROW(
"Unsupported at::ScalarType: ",
scalar_type_name(at_scalartype),
" (",
static_cast<int>(at_scalartype),
")");
}
}
32 changes: 32 additions & 0 deletions backends/vulkan/test/op_tests/test_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <string>

#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>
#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

/**
* Convert at::ScalarType to executorch::ScalarType
*/
executorch::aten::ScalarType at_scalartype_to_et_scalartype(
at::ScalarType dtype);

/**
* Get the string name of a c10::ScalarType for better error messages
*/
std::string scalar_type_name(c10::ScalarType dtype);

/**
* Convert c10::ScalarType to vkcompute::vkapi::ScalarType
*/
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype);
Loading