Skip to content

Commit 36ab2ac

Browse files
committed
[ET-VK] Handle scalar tensor and mutable buffer inputs in Vulkan delegate runtime
## Context * Handle scalar tensor inputs by adding them to the graph as symbolic ints * Add support for symint inputs in the Vulkan delegate * Add type checking for Vulkan delegate inputs and outputs This is needed for Transformer models, which receive a an `input_pos` integer scalar tensor as an input. `input_pos` is used in KV cache updates and determines the sizes of the cache slices. ### Why are scalar tensors added as symint? Adding scalar tensors as symints makes more sense than adding them as real tensors, since symints are commonly used to inform tensor shapes. Adding scalar tensors as symints allow them to be easily accessible by the CPU at graph encoding and resizing time, as well as easily accesible by the GPU within compute shaders. Differential Revision: [D63979312](https://our.internmc.facebook.com/intern/diff/D63979312/) ghstack-source-id: 246578362 Pull Request resolved: #5930
1 parent a9586a5 commit 36ab2ac

File tree

2 files changed

+82
-18
lines changed

2 files changed

+82
-18
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@ class GraphBuilder {
192192
UIntVector dims_fb = tensor_fb->dims();
193193
const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());
194194

195+
// For scalar tensors, add them as SymInts instead of tensors
196+
if (dtype == vkapi::kInt && utils::multiply_integers(dims_vector) == 1) {
197+
ref_mapping_[fb_id] = compute_graph_->add_symint(0);
198+
return;
199+
}
200+
195201
utils::GPUMemoryLayout memory_layout =
196202
tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT
197203
? compute_graph_->suggested_memory_layout(dims_vector)
@@ -312,10 +318,16 @@ class GraphBuilder {
312318
add_value_to_graph(fb_id, value);
313319
}
314320

315-
// Parse the inputs
321+
// Parse the inputs, which will be tensors most of the time but can also be
322+
// symints and tensorrefs (which will be the case if the original graph had)
323+
// mutable buffers.
316324
for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
317325
const ValueRef ref = get_fb_id_valueref(fb_id);
318-
compute_graph_->set_input_tensor(ref);
326+
if (compute_graph_->val_is_tensor(ref)) {
327+
compute_graph_->set_input_tensor(ref);
328+
} else {
329+
compute_graph_->set_val_as_input(ref);
330+
}
319331
}
320332

321333
// Parse the operators
@@ -354,10 +366,15 @@ class GraphBuilder {
354366
}
355367
}
356368

357-
// Parse the outputs
369+
// Parse the outputs, which will be mostly tensors. For some reason,
370+
// mutable buffers are shown to be returned in the fx.Graph but do not get
371+
// returned by the delegate; this may be an implementation detail of how the
372+
// executorch emitter handles mutable buffers.
358373
for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
359374
const ValueRef ref = get_fb_id_valueref(fb_id);
360-
compute_graph_->set_output_tensor(ref);
375+
if (compute_graph_->val_is_tensor(ref)) {
376+
compute_graph_->set_output_tensor(ref);
377+
}
361378
}
362379
}
363380
};
@@ -508,13 +525,44 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
508525
const size_t num_inputs = compute_graph->inputs().size();
509526
bool should_propagate_resize = false;
510527
for (size_t i = 0; i < num_inputs; i++) {
511-
bool was_resized =
512-
maybe_resize_input(compute_graph, i, args[i]->toTensor());
513-
should_propagate_resize = should_propagate_resize || was_resized;
514-
compute_graph->copy_into_staging(
515-
compute_graph->inputs()[i].staging,
516-
args[i]->toTensor().const_data_ptr(),
517-
args[i]->toTensor().numel());
528+
const ValueRef iref = compute_graph->inputs()[i].value;
529+
if (compute_graph->val_is_tensor(iref)) {
530+
VK_CHECK_COND(args[i]->isTensor());
531+
bool was_resized =
532+
maybe_resize_input(compute_graph, i, args[i]->toTensor());
533+
should_propagate_resize = should_propagate_resize || was_resized;
534+
compute_graph->copy_into_staging(
535+
compute_graph->inputs()[i].staging,
536+
args[i]->toTensor().const_data_ptr(),
537+
args[i]->toTensor().numel());
538+
} else if (compute_graph->val_is_symint(iref)) {
539+
int32_t scalar_tensor_val = 0;
540+
const int32_t cur_val = compute_graph->read_symint(iref);
541+
if (args[i]->isTensor()) {
542+
exec_aten::Tensor& scalar_tensor_src = args[i]->toTensor();
543+
exec_aten::ScalarType dtype = scalar_tensor_src.scalar_type();
544+
if (dtype == exec_aten::ScalarType::Int) {
545+
scalar_tensor_val = *scalar_tensor_src.const_data_ptr<int32_t>();
546+
} else if (dtype == exec_aten::ScalarType::Long) {
547+
scalar_tensor_val =
548+
int32_t(*scalar_tensor_src.const_data_ptr<int64_t>());
549+
}
550+
if (scalar_tensor_val != cur_val) {
551+
compute_graph->set_symint(iref, scalar_tensor_val);
552+
// Since symint inputs may impact tensor's sizes, trigger a resize
553+
// if any symbolic integer shapes are updated.
554+
should_propagate_resize = true;
555+
}
556+
} else {
557+
VK_THROW(
558+
"Cannot handle symint arg to graph that is not derived from a "
559+
"scalar tensor at the moment.");
560+
}
561+
} else {
562+
VK_THROW(
563+
"Could not handle input with type ",
564+
compute_graph->get_val_type(iref));
565+
}
518566
}
519567

520568
if (should_propagate_resize) {
@@ -523,13 +571,21 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
523571
compute_graph->execute();
524572

525573
for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
526-
maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
527-
// args holds inputs directly followed by outputs, so the i'th output
528-
// for compute_graph corresponds to the (i + num_inputs)'th arg
529-
compute_graph->copy_from_staging(
530-
compute_graph->outputs()[i].staging,
531-
args[num_inputs + i]->toTensor().mutable_data_ptr(),
532-
args[num_inputs + i]->toTensor().numel());
574+
const ValueRef oref = compute_graph->outputs()[i].value;
575+
if (compute_graph->val_is_tensor(oref)) {
576+
VK_CHECK_COND(args[i]->isTensor());
577+
maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
578+
// args holds inputs directly followed by outputs, so the i'th output
579+
// for compute_graph corresponds to the (i + num_inputs)'th arg
580+
compute_graph->copy_from_staging(
581+
compute_graph->outputs()[i].staging,
582+
args[num_inputs + i]->toTensor().mutable_data_ptr(),
583+
args[num_inputs + i]->toTensor().numel());
584+
} else {
585+
VK_THROW(
586+
"Could not handle output with type ",
587+
compute_graph->get_val_type(oref));
588+
}
533589
}
534590

535591
#ifdef ET_EVENT_TRACER_ENABLED

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,14 @@ class ComputeGraph final {
555555

556556
int32_t read_symint(const ValueRef idx);
557557

558+
inline void set_val_as_input(const ValueRef idx) {
559+
inputs_.push_back({idx, kDummyValueRef});
560+
}
561+
562+
inline void set_val_as_output(const ValueRef idx) {
563+
outputs_.push_back({idx, kDummyValueRef});
564+
}
565+
558566
/*
559567
* Convenience function to add an input tensor along with its staging buffer
560568
*/

0 commit comments

Comments
 (0)