Commit 215bf66
[ET-VK] Handle scalar tensor and mutable buffer inputs in Vulkan delegate runtime
Pull Request resolved: #5930

## Context

* Handle scalar tensor inputs by adding them to the graph as symbolic ints
* Add support for symint inputs in the Vulkan delegate
* Add type checking for Vulkan delegate inputs and outputs

This is needed for Transformer models, which receive an `input_pos` integer scalar tensor as an input. `input_pos` is used in KV cache updates and determines the sizes of the cache slices.

Additionally, mutable buffer inputs/outputs, which appear as `TensorRef` to the Vulkan graph, are handled as well, by ignoring them when copying outputs. More details in the comments.

### Why are scalar tensors added as symints?

Adding scalar tensors as symints makes more sense than adding them as real tensors, since symints are commonly used to inform tensor shapes. Adding scalar tensors as symints allows them to be easily accessible by the CPU at graph encoding and resizing time, as well as easily accessible by the GPU within compute shaders.

ghstack-source-id: 246627218
Differential Revision: [D63979312](https://our.internmc.facebook.com/intern/diff/D63979312/)
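To make the mechanism concrete, here is a small self-contained sketch (not code from this diff) of the round trip the commit implements: a 1-element integer tensor arrives as a runtime input, its value is mirrored into a host-visible symbolic int, and a change in that value flags a resize. All types here (`FakeScalarTensor`, `SymInt`) are stand-ins invented for illustration:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for a scalar (1-element) integer tensor input, e.g. input_pos.
struct FakeScalarTensor {
  std::vector<int64_t> sizes;  // {} or {1}: numel == 1
  int64_t value;
};

// Stand-in for the graph's symbolic int: a host-visible value that both
// resize logic (CPU) and compute shaders (GPU) can consult.
struct SymInt {
  int32_t value = 0;
};

// Mirrors the role of maybe_update_scalar_tensor: copy the tensor's payload
// into the symint and report whether anything changed.
bool maybe_update(SymInt& sym, const FakeScalarTensor& t) {
  const int32_t new_val = static_cast<int32_t>(t.value);
  if (new_val == sym.value) {
    return false;
  }
  sym.value = new_val;
  return true;
}

int main() {
  SymInt input_pos_sym;
  FakeScalarTensor input_pos{{1}, 17};
  // A changed value would trigger shape repropagation in the real graph.
  const bool should_propagate_resize = maybe_update(input_pos_sym, input_pos);
  std::printf("symint = %d, resize = %s\n", input_pos_sym.value,
              should_propagate_resize ? "yes" : "no");
  return 0;
}
```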
1 parent 29bf7df commit 215bf66

2 files changed: +91 -19

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 83 additions & 19 deletions
@@ -192,6 +192,12 @@ class GraphBuilder {
     UIntVector dims_fb = tensor_fb->dims();
     const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());
 
+    // For scalar tensors, add them as SymInts instead of tensors
+    if (dtype == vkapi::kInt && utils::multiply_integers(dims_vector) == 1) {
+      ref_mapping_[fb_id] = compute_graph_->add_symint(0);
+      return;
+    }
+
     utils::GPUMemoryLayout memory_layout =
         tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT
         ? compute_graph_->suggested_memory_layout(dims_vector)
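For context on the scalar check in the hunk above: a tensor is treated as scalar when the product of its dims is 1 (shape `{}` or `{1}`). Assuming `utils::multiply_integers` is a plain product fold over the dims, a minimal equivalent would be:

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical equivalent of utils::multiply_integers: fold the dims into
// their product. An empty dims vector (a rank-0 tensor) yields 1, so both
// shape {} and shape {1} count as scalar.
int64_t multiply_integers(const std::vector<int64_t>& dims) {
  return std::accumulate(
      dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
}
```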
@@ -312,10 +318,16 @@ class GraphBuilder {
       add_value_to_graph(fb_id, value);
     }
 
-    // Parse the inputs
+    // Parse the inputs, which will be tensors most of the time but can also
+    // be symints and tensorrefs (which will be the case if the original
+    // graph had mutable buffers).
     for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      compute_graph_->set_input_tensor(ref);
+      if (compute_graph_->val_is_tensor(ref)) {
+        compute_graph_->set_input_tensor(ref);
+      } else {
+        compute_graph_->set_val_as_input(ref);
+      }
     }
 
     // Parse the operators
@@ -354,10 +366,15 @@ class GraphBuilder {
       }
     }
 
-    // Parse the outputs
+    // Parse the outputs, which will mostly be tensors. Mutable buffers
+    // appear as outputs in the fx.Graph but are not actually returned by
+    // the delegate; this may be an implementation detail of how the
+    // executorch emitter handles mutable buffers.
     for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      compute_graph_->set_output_tensor(ref);
+      if (compute_graph_->val_is_tensor(ref)) {
+        compute_graph_->set_output_tensor(ref);
+      }
     }
   }
 };
@@ -401,6 +418,26 @@ bool maybe_resize_input(
   return should_resize;
 }
 
+bool maybe_update_scalar_tensor(
+    ComputeGraph* graph,
+    const ValueRef ref,
+    executorch::aten::Tensor& scalar_tensor_src) {
+  const int32_t cur_val = graph->read_symint(ref);
+  int32_t scalar_tensor_val = 0;
+  exec_aten::ScalarType dtype = scalar_tensor_src.scalar_type();
+  if (dtype == exec_aten::ScalarType::Int) {
+    scalar_tensor_val = *scalar_tensor_src.const_data_ptr<int32_t>();
+  } else if (dtype == exec_aten::ScalarType::Long) {
+    scalar_tensor_val = int32_t(*scalar_tensor_src.const_data_ptr<int64_t>());
+  }
+  bool was_updated = false;
+  if (scalar_tensor_val != cur_val) {
+    graph->set_symint(ref, scalar_tensor_val);
+    was_updated = true;
+  }
+  return was_updated;
+}
+
 void maybe_resize_output(
     ComputeGraph* graph,
     const size_t output_i,
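One detail worth noting in `maybe_update_scalar_tensor` above: `Long` payloads are narrowed to `int32_t`, and any dtype other than `Int` or `Long` leaves the value at 0. A standalone demonstration of that narrowing (plain C++, no ExecuTorch types):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Values within int32 range survive the int64 -> int32 narrowing intact,
  // which covers realistic scalars such as KV-cache positions.
  const int64_t pos = 4096;
  std::printf("%d\n", static_cast<int32_t>(pos));  // prints 4096

  // Values beyond INT32_MAX wrap (a well-defined modular conversion since
  // C++20), so callers are expected to pass small scalar values.
  const int64_t big = 5000000000LL;
  std::printf("%d\n", static_cast<int32_t>(big));  // prints 705032704
  return 0;
}
```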
@@ -487,7 +524,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
 
     Error err = compileModel(processed->data(), compute_graph);
 
-    // This backend does not need its processed data after compiling the model.
+    // This backend does not need its processed data after compiling the
+    // model.
     processed->Free();
 
     if (err != Error::Ok) {
@@ -508,13 +546,31 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
     const size_t num_inputs = compute_graph->inputs().size();
     bool should_propagate_resize = false;
     for (size_t i = 0; i < num_inputs; i++) {
-      bool was_resized =
-          maybe_resize_input(compute_graph, i, args[i]->toTensor());
-      should_propagate_resize = should_propagate_resize || was_resized;
-      compute_graph->copy_into_staging(
-          compute_graph->inputs()[i].staging,
-          args[i]->toTensor().const_data_ptr(),
-          args[i]->toTensor().numel());
+      const ValueRef iref = compute_graph->inputs()[i].value;
+      if (compute_graph->val_is_tensor(iref)) {
+        VK_CHECK_COND(args[i]->isTensor());
+        bool was_resized =
+            maybe_resize_input(compute_graph, i, args[i]->toTensor());
+        should_propagate_resize = should_propagate_resize || was_resized;
+        compute_graph->copy_into_staging(
+            compute_graph->inputs()[i].staging,
+            args[i]->toTensor().const_data_ptr(),
+            args[i]->toTensor().numel());
+      } else if (compute_graph->val_is_symint(iref)) {
+        VK_CHECK_COND(
+            args[i]->isTensor(),
+            "Cannot handle symint arg to graph that is not derived from a "
+            "scalar tensor at the moment.");
+        bool was_updated = maybe_update_scalar_tensor(
+            compute_graph, iref, args[i]->toTensor());
+        // Since symint inputs may impact tensor sizes, trigger a resize if
+        // any symbolic integer values are updated.
+        should_propagate_resize = should_propagate_resize || was_updated;
+      } else {
+        VK_THROW(
+            "Could not handle input with type ",
+            compute_graph->get_val_type(iref));
+      }
     }
 
     if (should_propagate_resize) {
@@ -523,13 +579,21 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
     compute_graph->execute();
 
     for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
-      maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
-      // args holds inputs directly followed by outputs, so the i'th output
-      // for compute_graph corresponds to the (i + num_inputs)'th arg
-      compute_graph->copy_from_staging(
-          compute_graph->outputs()[i].staging,
-          args[num_inputs + i]->toTensor().mutable_data_ptr(),
-          args[num_inputs + i]->toTensor().numel());
+      const ValueRef oref = compute_graph->outputs()[i].value;
+      if (compute_graph->val_is_tensor(oref)) {
+        // args holds inputs directly followed by outputs, so the i'th output
+        // for compute_graph corresponds to the (i + num_inputs)'th arg
+        VK_CHECK_COND(args[num_inputs + i]->isTensor());
+        maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
+        compute_graph->copy_from_staging(
+            compute_graph->outputs()[i].staging,
+            args[num_inputs + i]->toTensor().mutable_data_ptr(),
+            args[num_inputs + i]->toTensor().numel());
+      } else {
+        VK_THROW(
+            "Could not handle output with type ",
+            compute_graph->get_val_type(oref));
+      }
     }
 
 #ifdef ET_EVENT_TRACER_ENABLED

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 8 additions & 0 deletions
@@ -555,6 +555,14 @@ class ComputeGraph final {
 
   int32_t read_symint(const ValueRef idx);
 
+  inline void set_val_as_input(const ValueRef idx) {
+    inputs_.push_back({idx, kDummyValueRef});
+  }
+
+  inline void set_val_as_output(const ValueRef idx) {
+    outputs_.push_back({idx, kDummyValueRef});
+  }
+
   /*
    * Convenience function to add an input tensor along with its staging buffer
    */

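A note on the `kDummyValueRef` staging slot used by these helpers: tensor inputs are paired with a staging buffer that `execute()` copies through, whereas symint and mutable-buffer inputs have nothing to stage, so the slot is filled with a dummy ref and the copy loops in VulkanBackend.cpp never touch it. A self-contained sketch of that pairing, with `ValueRef` and `kDummyValueRef` stand-ins assumed for illustration:

```cpp
#include <cstdint>
#include <vector>

using ValueRef = int32_t;                // stand-in for the real handle type
constexpr ValueRef kDummyValueRef = -1;  // assumed sentinel value

// Pairing suggested by the inputs()[i].value / .staging accessors used in
// VulkanBackend.cpp: each graph input couples a value with optional staging.
struct IOValueRef {
  ValueRef value;    // tensor, symint, or tensorref in the graph
  ValueRef staging;  // staging buffer for host<->device copies, or dummy
};

int main() {
  std::vector<IOValueRef> inputs;
  inputs.push_back({/*value=*/3, kDummyValueRef});  // symint: nothing staged
  inputs.push_back({/*value=*/7, /*staging=*/8});   // tensor: staged copies
  return 0;
}
```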