@@ -192,6 +192,12 @@ class GraphBuilder {
     UIntVector dims_fb = tensor_fb->dims();
     const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());
 
+    // For scalar int tensors, add them as SymInts instead of tensors
+    if (dtype == vkapi::kInt && utils::multiply_integers(dims_vector) == 1) {
+      ref_mapping_[fb_id] = compute_graph_->add_symint(0);
+      return;
+    }
+
     utils::GPUMemoryLayout memory_layout =
         tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT
         ? compute_graph_->suggested_memory_layout(dims_vector)
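
Note on the scalar check above: a tensor's element count is the product of its dims, so a product of 1 identifies a single-element (scalar) tensor regardless of rank. A minimal standalone sketch of that idea, with multiply_integers re-implemented here purely for illustration (the real helper lives in the utils namespace):

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Illustrative stand-in for utils::multiply_integers: the product of
    // all dims equals the tensor's element count.
    int64_t multiply_integers(const std::vector<int64_t>& dims) {
      return std::accumulate(
          dims.begin(), dims.end(), int64_t(1), std::multiplies<int64_t>());
    }

    // An int tensor with exactly one element (e.g. shape {} or {1, 1}) is
    // the pattern that gets promoted to a SymInt in the hunk above.
    bool is_scalar_tensor(const std::vector<int64_t>& dims) {
      return multiply_integers(dims) == 1;
    }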
@@ -312,10 +318,16 @@ class GraphBuilder {
       add_value_to_graph(fb_id, value);
     }
 
-    // Parse the inputs
+    // Parse the inputs, which will be tensors most of the time but can also
+    // be symints and tensorrefs (which will be the case if the original
+    // graph had mutable buffers).
     for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      compute_graph_->set_input_tensor(ref);
+      if (compute_graph_->val_is_tensor(ref)) {
+        compute_graph_->set_input_tensor(ref);
+      } else {
+        compute_graph_->set_val_as_input(ref);
+      }
     }
 
     // Parse the operators
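
For intuition, the input-parsing hunk above is a type switch over the graph's value kinds. A schematic sketch with hypothetical stand-in types (the real ComputeGraph tracks value kinds behind ValueRef handles rather than via std::variant):

    #include <variant>

    // Hypothetical stand-ins for the graph's value kinds.
    struct Tensor {};
    struct SymInt {};
    struct TensorRef {};
    using Value = std::variant<Tensor, SymInt, TensorRef>;

    void register_input(const Value& v) {
      if (std::holds_alternative<Tensor>(v)) {
        // Tensor inputs get staging buffers: set_input_tensor(ref).
      } else {
        // SymInts and tensorrefs are recorded as plain input values with
        // no staging: set_val_as_input(ref).
      }
    }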
@@ -354,10 +366,15 @@ class GraphBuilder {
       }
     }
 
-    // Parse the outputs
+    // Parse the outputs, which will mostly be tensors. Mutable buffers
+    // appear among the fx.Graph's outputs but are not returned by the
+    // delegate; this may be an implementation detail of how the executorch
+    // emitter handles mutable buffers.
     for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
      const ValueRef ref = get_fb_id_valueref(fb_id);
-      compute_graph_->set_output_tensor(ref);
+      if (compute_graph_->val_is_tensor(ref)) {
+        compute_graph_->set_output_tensor(ref);
+      }
     }
   }
 };
@@ -508,13 +525,44 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
     const size_t num_inputs = compute_graph->inputs().size();
     bool should_propagate_resize = false;
     for (size_t i = 0; i < num_inputs; i++) {
-      bool was_resized =
-          maybe_resize_input(compute_graph, i, args[i]->toTensor());
-      should_propagate_resize = should_propagate_resize || was_resized;
-      compute_graph->copy_into_staging(
-          compute_graph->inputs()[i].staging,
-          args[i]->toTensor().const_data_ptr(),
-          args[i]->toTensor().numel());
+      const ValueRef iref = compute_graph->inputs()[i].value;
+      if (compute_graph->val_is_tensor(iref)) {
+        VK_CHECK_COND(args[i]->isTensor());
+        bool was_resized =
+            maybe_resize_input(compute_graph, i, args[i]->toTensor());
+        should_propagate_resize = should_propagate_resize || was_resized;
+        compute_graph->copy_into_staging(
+            compute_graph->inputs()[i].staging,
+            args[i]->toTensor().const_data_ptr(),
+            args[i]->toTensor().numel());
+      } else if (compute_graph->val_is_symint(iref)) {
+        int32_t scalar_tensor_val = 0;
+        const int32_t cur_val = compute_graph->read_symint(iref);
+        if (args[i]->isTensor()) {
+          exec_aten::Tensor& scalar_tensor_src = args[i]->toTensor();
+          exec_aten::ScalarType dtype = scalar_tensor_src.scalar_type();
+          if (dtype == exec_aten::ScalarType::Int) {
+            scalar_tensor_val = *scalar_tensor_src.const_data_ptr<int32_t>();
+          } else if (dtype == exec_aten::ScalarType::Long) {
+            scalar_tensor_val =
+                int32_t(*scalar_tensor_src.const_data_ptr<int64_t>());
+          }
+          if (scalar_tensor_val != cur_val) {
+            compute_graph->set_symint(iref, scalar_tensor_val);
+            // Since symint inputs may affect tensor sizes, trigger a resize
+            // if any symbolic integer values are updated.
+            should_propagate_resize = true;
+          }
+        } else {
+          VK_THROW(
+              "Cannot handle symint arg to graph that is not derived from a "
+              "scalar tensor at the moment.");
+        }
+      } else {
+        VK_THROW(
+            "Could not handle input with type ",
+            compute_graph->get_val_type(iref));
+      }
     }
 
     if (should_propagate_resize) {
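
The symint branch above boils down to two steps: read the scalar value out of the incoming tensor, narrowing Long to the 32-bit storage the graph's SymInt uses, and mark a resize only when the value actually changed. A self-contained sketch of that logic, with a hypothetical ScalarTensor standing in for exec_aten::Tensor:

    #include <cstdint>
    #include <stdexcept>

    enum class ScalarType { Int, Long };

    // Hypothetical stand-in for a one-element exec_aten::Tensor.
    struct ScalarTensor {
      ScalarType dtype;
      const void* data;
    };

    // Narrow the scalar to the int32_t that the graph's SymInt stores.
    int32_t read_scalar_as_int32(const ScalarTensor& t) {
      switch (t.dtype) {
        case ScalarType::Int:
          return *static_cast<const int32_t*>(t.data);
        case ScalarType::Long:
          // Assumes the symbolic value fits in 32 bits, mirroring the
          // int32_t(...) cast in the diff.
          return static_cast<int32_t>(*static_cast<const int64_t*>(t.data));
      }
      throw std::runtime_error("unsupported scalar dtype");
    }

    // Returns true when the stored value changed, i.e. when the caller
    // should set should_propagate_resize.
    bool update_symint(int32_t& stored, const ScalarTensor& t) {
      const int32_t new_val = read_scalar_as_int32(t);
      if (new_val == stored) {
        return false;
      }
      stored = new_val;
      return true;
    }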
@@ -523,13 +571,21 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
     compute_graph->execute();
 
     for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
-      maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
-      // args holds inputs directly followed by outputs, so the i'th output
-      // for compute_graph corresponds to the (i + num_inputs)'th arg
-      compute_graph->copy_from_staging(
-          compute_graph->outputs()[i].staging,
-          args[num_inputs + i]->toTensor().mutable_data_ptr(),
-          args[num_inputs + i]->toTensor().numel());
+      const ValueRef oref = compute_graph->outputs()[i].value;
+      if (compute_graph->val_is_tensor(oref)) {
+        // args holds inputs directly followed by outputs, so the i'th output
+        // for compute_graph corresponds to the (i + num_inputs)'th arg
+        VK_CHECK_COND(args[num_inputs + i]->isTensor());
+        maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
+        compute_graph->copy_from_staging(
+            compute_graph->outputs()[i].staging,
+            args[num_inputs + i]->toTensor().mutable_data_ptr(),
+            args[num_inputs + i]->toTensor().numel());
+      } else {
+        VK_THROW(
+            "Could not handle output with type ",
+            compute_graph->get_val_type(oref));
+      }
     }
 
 #ifdef ET_EVENT_TRACER_ENABLED
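
One indexing detail worth spelling out for the output loop: args is a single flat array holding all inputs first and all outputs after them, which is why the i'th output is read from args[num_inputs + i]. A tiny sketch of that layout, with a hypothetical EValue stand-in for the runtime's boxed argument type:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    struct EValue {}; // hypothetical stand-in for the runtime's boxed value

    // args = [in_0, ..., in_{N-1}, out_0, ..., out_{M-1}], so the i'th
    // graph output lives at index num_inputs + i.
    EValue* output_arg(std::vector<EValue*>& args, size_t num_inputs, size_t i) {
      assert(num_inputs + i < args.size());
      return args[num_inputs + i];
    }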