Skip to content

Commit 9923be9

Browse files
author
morelos
committed
Update base for Update on "[ET-VK] double, short, and uint16 dtype runtime support"
Creating support for double, short, and uint16 for quantization ops. Registering the short keyword since theres already support. Also changing the cpu implementation to support half Differential Revision: [D75959063](https://our.internmc.facebook.com/intern/diff/D75959063/) [ghstack-poisoned]
2 parents e5e2fc6 + cbd3874 commit 9923be9

File tree

90 files changed

+914
-1171
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+914
-1171
lines changed

.lintrunner.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,10 +271,6 @@ exclude_patterns = [
271271
'examples/**',
272272
'exir/verification/bindings.cpp',
273273
'extension/**',
274-
# Uses properly-gated (ET_USE_PYTORCH_HEADERS) ATen include.
275-
'kernels/portable/cpu/util/elementwise_util.h',
276-
'kernels/portable/cpu/util/math_util.h',
277-
'kernels/portable/cpu/util/vectorized_math.h',
278274
'kernels/optimized/**',
279275
'runtime/core/exec_aten/**',
280276
# Want to be able to keep c10 in sync with PyTorch core.

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1515

16+
#include <executorch/backends/vulkan/runtime/vk_api/Runtime.h>
17+
1618
#include <executorch/runtime/backend/interface.h>
1719
#include <executorch/runtime/core/error.h>
1820
#include <executorch/runtime/core/evalue.h>
@@ -528,7 +530,9 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
528530
return Error::MemoryAllocationFailed;
529531
}
530532

531-
new (compute_graph) ComputeGraph(get_graph_config(compile_specs));
533+
GraphConfig graph_config = get_graph_config(compile_specs);
534+
graph_config.external_adapter = vkapi::set_and_get_external_adapter();
535+
new (compute_graph) ComputeGraph(graph_config);
532536

533537
Error err = compileModel(processed->data(), compute_graph);
534538

backends/vulkan/runtime/api/Context.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
namespace vkcompute {
2525
namespace api {
2626

27-
Context::Context(size_t adapter_i, const ContextConfig& config)
27+
Context::Context(vkapi::Adapter* adapter, const ContextConfig& config)
2828
: config_(config),
2929
// Important handles
30-
adapter_p_(vkapi::runtime()->get_adapter_p(adapter_i)),
30+
adapter_p_(adapter),
3131
device_(adapter_p_->device_handle()),
3232
queue_(adapter_p_->request_queue()),
3333
// Resource pools
@@ -256,7 +256,7 @@ Context* context() {
256256
query_pool_config,
257257
};
258258

259-
return new Context(vkapi::runtime()->default_adapter_i(), config);
259+
return new Context(vkapi::runtime()->get_adapter_p(), config);
260260
} catch (...) {
261261
}
262262

backends/vulkan/runtime/api/Context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct ContextConfig final {
4242

4343
class Context final {
4444
public:
45-
explicit Context(size_t adapter_i, const ContextConfig&);
45+
explicit Context(vkapi::Adapter*, const ContextConfig&);
4646

4747
Context(const Context&) = delete;
4848
Context& operator=(const Context&) = delete;

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ ComputeGraph::ComputeGraph(GraphConfig config)
122122
prepack_descriptor_counts_{},
123123
execute_descriptor_counts_{},
124124
context_{new api::Context(
125-
vkapi::runtime()->default_adapter_i(),
125+
config.external_adapter ? config.external_adapter
126+
: vkapi::runtime()->get_adapter_p(),
126127
config_.context_config)},
127128
shared_objects_{},
128129
values_{},

backends/vulkan/runtime/graph/GraphConfig.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ GraphConfig::GraphConfig() {
6565
local_wg_size_override = {};
6666

6767
expect_dynamic_shapes = false;
68+
69+
external_adapter = nullptr;
6870
}
6971

7072
void GraphConfig::set_storage_type_override(utils::StorageType storage_type) {

backends/vulkan/runtime/graph/GraphConfig.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ struct GraphConfig final {
3636
// Whether or not the ComputeGraph should expect input shapes to be dynamic
3737
bool expect_dynamic_shapes;
3838

39+
vkapi::Adapter* external_adapter;
40+
3941
// Generate a default graph config with pre-configured settings
4042
explicit GraphConfig();
4143

backends/vulkan/runtime/vk_api/Adapter.cpp

Lines changed: 106 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,12 @@ namespace vkapi {
1717

1818
namespace {
1919

20-
VkDevice create_logical_device(
20+
void find_compute_queues(
2121
const PhysicalDevice& physical_device,
2222
const uint32_t num_queues_to_create,
23-
std::vector<Adapter::Queue>& queues,
24-
std::vector<uint32_t>& queue_usage) {
25-
// Find compute queues up to the requested number of queues
26-
27-
std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
23+
std::vector<VkDeviceQueueCreateInfo>& queue_create_infos,
24+
std::vector<std::pair<uint32_t, uint32_t>>& queues_to_get) {
2825
queue_create_infos.reserve(num_queues_to_create);
29-
30-
std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
3126
queues_to_get.reserve(num_queues_to_create);
3227

3328
uint32_t remaining_queues = num_queues_to_create;
@@ -60,12 +55,44 @@ VkDevice create_logical_device(
6055
break;
6156
}
6257
}
58+
}
6359

60+
void populate_queue_info(
61+
const PhysicalDevice& physical_device,
62+
VkDevice logical_device,
63+
const std::vector<std::pair<uint32_t, uint32_t>>& queues_to_get,
64+
std::vector<Adapter::Queue>& queues,
65+
std::vector<uint32_t>& queue_usage) {
6466
queues.reserve(queues_to_get.size());
6567
queue_usage.reserve(queues_to_get.size());
6668

67-
// Create the VkDevice
69+
// Obtain handles for the created queues and initialize queue usage heuristic
70+
71+
for (const std::pair<uint32_t, uint32_t>& queue_idx : queues_to_get) {
72+
VkQueue queue_handle = VK_NULL_HANDLE;
73+
VkQueueFlags flags =
74+
physical_device.queue_families.at(queue_idx.first).queueFlags;
75+
vkGetDeviceQueue(
76+
logical_device, queue_idx.first, queue_idx.second, &queue_handle);
77+
queues.push_back({queue_idx.first, queue_idx.second, flags, queue_handle});
78+
// Initial usage value
79+
queue_usage.push_back(0);
80+
}
81+
}
82+
83+
VkDevice create_logical_device(
84+
const PhysicalDevice& physical_device,
85+
const uint32_t num_queues_to_create,
86+
std::vector<Adapter::Queue>& queues,
87+
std::vector<uint32_t>& queue_usage) {
88+
// Find compute queues up to the requested number of queues
6889

90+
std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
91+
std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
92+
find_compute_queues(
93+
physical_device, num_queues_to_create, queue_create_infos, queues_to_get);
94+
95+
// Create the VkDevice
6996
std::vector<const char*> requested_device_extensions{
7097
#ifdef VK_KHR_portability_subset
7198
VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
@@ -143,19 +170,42 @@ VkDevice create_logical_device(
143170
volkLoadDevice(handle);
144171
#endif /* USE_VULKAN_VOLK */
145172

146-
// Obtain handles for the created queues and initialize queue usage heuristic
173+
populate_queue_info(
174+
physical_device, handle, queues_to_get, queues, queue_usage);
147175

148-
for (const std::pair<uint32_t, uint32_t>& queue_idx : queues_to_get) {
149-
VkQueue queue_handle = VK_NULL_HANDLE;
150-
VkQueueFlags flags =
151-
physical_device.queue_families.at(queue_idx.first).queueFlags;
152-
vkGetDeviceQueue(handle, queue_idx.first, queue_idx.second, &queue_handle);
153-
queues.push_back({queue_idx.first, queue_idx.second, flags, queue_handle});
154-
// Initial usage value
155-
queue_usage.push_back(0);
176+
return handle;
177+
}
178+
179+
bool test_linear_tiling_3d_image_support(VkDevice device) {
180+
// Test creating a 3D image with linear tiling to see if it is supported.
181+
// According to the Vulkan spec, linear tiling may not be supported for 3D
182+
// images.
183+
VkExtent3D image_extents{1u, 1u, 1u};
184+
const VkImageCreateInfo image_create_info{
185+
VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType
186+
nullptr, // pNext
187+
0u, // flags
188+
VK_IMAGE_TYPE_3D, // imageType
189+
VK_FORMAT_R32G32B32A32_SFLOAT, // format
190+
image_extents, // extents
191+
1u, // mipLevels
192+
1u, // arrayLayers
193+
VK_SAMPLE_COUNT_1_BIT, // samples
194+
VK_IMAGE_TILING_LINEAR, // tiling
195+
VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT, // usage
196+
VK_SHARING_MODE_EXCLUSIVE, // sharingMode
197+
0u, // queueFamilyIndexCount
198+
nullptr, // pQueueFamilyIndices
199+
VK_IMAGE_LAYOUT_UNDEFINED, // initialLayout
200+
};
201+
VkImage image = VK_NULL_HANDLE;
202+
VkResult res = vkCreateImage(device, &image_create_info, nullptr, &image);
203+
204+
if (res == VK_SUCCESS) {
205+
vkDestroyImage(device, image, nullptr);
156206
}
157207

158-
return handle;
208+
return res == VK_SUCCESS;
159209
}
160210

161211
} // namespace
@@ -186,37 +236,44 @@ Adapter::Adapter(
186236
compute_pipeline_cache_(device_.handle, cache_data_path),
187237
sampler_cache_(device_.handle),
188238
vma_(instance_, physical_device_.handle, device_.handle),
189-
linear_tiling_3d_enabled_{true} {
190-
// Test creating a 3D image with linear tiling to see if it is supported.
191-
// According to the Vulkan spec, linear tiling may not be supported for 3D
192-
// images.
193-
VkExtent3D image_extents{1u, 1u, 1u};
194-
const VkImageCreateInfo image_create_info{
195-
VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType
196-
nullptr, // pNext
197-
0u, // flags
198-
VK_IMAGE_TYPE_3D, // imageType
199-
VK_FORMAT_R32G32B32A32_SFLOAT, // format
200-
image_extents, // extents
201-
1u, // mipLevels
202-
1u, // arrayLayers
203-
VK_SAMPLE_COUNT_1_BIT, // samples
204-
VK_IMAGE_TILING_LINEAR, // tiling
205-
VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT, // usage
206-
VK_SHARING_MODE_EXCLUSIVE, // sharingMode
207-
0u, // queueFamilyIndexCount
208-
nullptr, // pQueueFamilyIndices
209-
VK_IMAGE_LAYOUT_UNDEFINED, // initialLayout
210-
};
211-
VkImage image = VK_NULL_HANDLE;
212-
VkResult res =
213-
vkCreateImage(device_.handle, &image_create_info, nullptr, &image);
214-
if (res != VK_SUCCESS) {
215-
linear_tiling_3d_enabled_ = false;
216-
} else {
217-
vkDestroyImage(device_.handle, image, nullptr);
239+
linear_tiling_3d_enabled_{
240+
test_linear_tiling_3d_image_support(device_.handle)},
241+
owns_device_{true} {}
242+
243+
Adapter::Adapter(
244+
VkInstance instance,
245+
VkPhysicalDevice physical_device,
246+
VkDevice logical_device,
247+
const uint32_t num_queues,
248+
const std::string& cache_data_path)
249+
: queue_usage_mutex_{},
250+
physical_device_(physical_device),
251+
queues_{},
252+
queue_usage_{},
253+
queue_mutexes_{},
254+
instance_(instance),
255+
device_(logical_device),
256+
shader_layout_cache_(device_.handle),
257+
shader_cache_(device_.handle),
258+
pipeline_layout_cache_(device_.handle),
259+
compute_pipeline_cache_(device_.handle, cache_data_path),
260+
sampler_cache_(device_.handle),
261+
vma_(instance_, physical_device_.handle, device_.handle),
262+
linear_tiling_3d_enabled_{
263+
test_linear_tiling_3d_image_support(device_.handle)},
264+
owns_device_{false} {
265+
std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
266+
std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
267+
find_compute_queues(
268+
physical_device_, num_queues, queue_create_infos, queues_to_get);
269+
populate_queue_info(
270+
physical_device_, device_.handle, queues_to_get, queues_, queue_usage_);
271+
}
272+
273+
Adapter::~Adapter() {
274+
if (!owns_device_) {
275+
device_.handle = VK_NULL_HANDLE;
218276
}
219-
return;
220277
}
221278

222279
Adapter::Queue Adapter::request_queue() {

backends/vulkan/runtime/vk_api/Adapter.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,20 @@ class Adapter final {
5656
const uint32_t num_queues,
5757
const std::string& cache_data_path);
5858

59+
explicit Adapter(
60+
VkInstance instance,
61+
VkPhysicalDevice physical_device,
62+
VkDevice logical_device,
63+
const uint32_t num_queues,
64+
const std::string& cache_data_path);
65+
5966
Adapter(const Adapter&) = delete;
6067
Adapter& operator=(const Adapter&) = delete;
6168

6269
Adapter(Adapter&&) = delete;
6370
Adapter& operator=(Adapter&&) = delete;
6471

65-
~Adapter() = default;
72+
~Adapter();
6673

6774
struct Queue {
6875
uint32_t family_index;
@@ -94,6 +101,7 @@ class Adapter final {
94101
Allocator vma_;
95102
// Miscellaneous
96103
bool linear_tiling_3d_enabled_;
104+
bool owns_device_;
97105

98106
public:
99107
// Physical Device metadata

backends/vulkan/runtime/vk_api/Runtime.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
#include <iostream>
1515
#include <sstream>
1616

17+
#ifdef USE_VOLK_HEADER_ONLY
18+
// For volk.h, define this before including volk.h in exactly one CPP file.
19+
#define VOLK_IMPLEMENTATION
20+
#include <volk.h>
21+
#endif /* USE_VOLK_HEADER_ONLY */
22+
1723
namespace vkcompute {
1824
namespace vkapi {
1925

@@ -409,5 +415,35 @@ Runtime* runtime() {
409415
return p_runtime.get();
410416
}
411417

418+
std::unique_ptr<Adapter> init_external_adapter(
419+
const VkInstance instance,
420+
const VkPhysicalDevice physical_device,
421+
const VkDevice logical_device,
422+
const uint32_t num_queues,
423+
const std::string& cache_data_path) {
424+
if (instance == VK_NULL_HANDLE || physical_device == VK_NULL_HANDLE ||
425+
logical_device == VK_NULL_HANDLE) {
426+
return std::unique_ptr<Adapter>(nullptr);
427+
}
428+
429+
return std::make_unique<Adapter>(
430+
instance, physical_device, logical_device, num_queues, cache_data_path);
431+
}
432+
433+
Adapter* set_and_get_external_adapter(
434+
const VkInstance instance,
435+
const VkPhysicalDevice physical_device,
436+
const VkDevice logical_device) {
437+
static const std::unique_ptr<Adapter> p_external_adapter =
438+
init_external_adapter(
439+
instance,
440+
physical_device,
441+
logical_device,
442+
1,
443+
set_and_get_pipeline_cache_data_path(""));
444+
445+
return p_external_adapter.get();
446+
}
447+
412448
} // namespace vkapi
413449
} // namespace vkcompute

0 commit comments

Comments
 (0)