@@ -192,6 +192,12 @@ class GraphBuilder {
     UIntVector dims_fb = tensor_fb->dims();
     const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());

+    // For scalar tensors, add them as SymInts instead of tensors
+    if (dtype == vkapi::kInt && utils::multiply_integers(dims_vector) == 1) {
+      ref_mapping_[fb_id] = compute_graph_->add_symint(0);
+      return;
+    }
+
     utils::GPUMemoryLayout memory_layout =
         tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT
         ? compute_graph_->suggested_memory_layout(dims_vector)
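For reference, the scalar check above keys off the element count rather than the tensor's rank. Below is a minimal standalone sketch of the same test, assuming utils::multiply_integers folds the dims into a product (so a 0-dim tensor, whose empty product is 1, qualifies alongside shapes such as {1} or {1, 1}); the names are illustrative, not ExecuTorch API.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Sketch only: mirrors the dtype == vkapi::kInt && multiply_integers(dims_vector) == 1
// condition used to decide that an int tensor should become a SymInt.
bool should_lower_to_symint(bool is_int32_dtype, const std::vector<int64_t>& dims) {
  const int64_t numel = std::accumulate(
      dims.cbegin(), dims.cend(), int64_t{1}, std::multiplies<int64_t>());
  return is_int32_dtype && numel == 1;
}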
@@ -312,10 +318,16 @@ class GraphBuilder {
       add_value_to_graph(fb_id, value);
     }

-    // Parse the inputs
+    // Parse the inputs, which will be tensors most of the time but can also
+    // be symints and tensorrefs (which will be the case if the original graph
+    // had mutable buffers).
     for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      compute_graph_->set_input_tensor(ref);
+      if (compute_graph_->val_is_tensor(ref)) {
+        compute_graph_->set_input_tensor(ref);
+      } else {
+        compute_graph_->set_val_as_input(ref);
+      }
     }

     // Parse the operators
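Builder-side, input registration now dispatches on the value type. A hypothetical helper equivalent to the loop body above, using only calls that appear in this diff (a sketch, not a proposed refactor):

// Tensors are registered as graph inputs via set_input_tensor(), which pairs
// them with a staging buffer; SymInts and TensorRefs are recorded as plain
// input values via set_val_as_input().
void register_graph_input(ComputeGraph* compute_graph, const ValueRef ref) {
  if (compute_graph->val_is_tensor(ref)) {
    compute_graph->set_input_tensor(ref);
  } else {
    compute_graph->set_val_as_input(ref);
  }
}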
@@ -354,10 +366,15 @@ class GraphBuilder {
       }
     }

-    // Parse the outputs
+    // Parse the outputs, which will be mostly tensors. For some reason,
+    // mutable buffers are shown to be returned in the fx.Graph but do not get
+    // returned by the delegate; this may be an implementation detail of how the
+    // executorch emitter handles mutable buffers.
     for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      compute_graph_->set_output_tensor(ref);
+      if (compute_graph_->val_is_tensor(ref)) {
+        compute_graph_->set_output_tensor(ref);
+      }
     }
   }
 };
@@ -401,6 +418,26 @@ bool maybe_resize_input(
   return should_resize;
 }

+bool maybe_update_scalar_tensor(
+    ComputeGraph* graph,
+    const ValueRef ref,
+    executorch::aten::Tensor& scalar_tensor_src) {
+  const int32_t cur_val = graph->read_symint(ref);
+  int32_t scalar_tensor_val = 0;
+  exec_aten::ScalarType dtype = scalar_tensor_src.scalar_type();
+  if (dtype == exec_aten::ScalarType::Int) {
+    scalar_tensor_val = *scalar_tensor_src.const_data_ptr<int32_t>();
+  } else if (dtype == exec_aten::ScalarType::Long) {
+    scalar_tensor_val = int32_t(*scalar_tensor_src.const_data_ptr<int64_t>());
+  }
+  bool was_updated = false;
+  if (scalar_tensor_val != cur_val) {
+    graph->set_symint(ref, scalar_tensor_val);
+    was_updated = true;
+  }
+  return was_updated;
+}
+
 void maybe_resize_output(
     ComputeGraph* graph,
     const size_t output_i,
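The new maybe_update_scalar_tensor() keeps a SymInt in sync with the scalar tensor passed at execution time: Int values are read directly, Long values are narrowed to the 32-bit symint width, and the graph is only written to when the value actually changed. A standalone sketch of that core rule (illustrative names, not ExecuTorch API):

#include <cstdint>

// Returns true only when the cached value changed, so the caller can decide
// whether shape propagation is needed. Note the int64 -> int32 narrowing;
// values outside the int32 range would wrap.
bool update_if_changed(int32_t& cached_val, const void* scalar_data, bool is_long) {
  const int32_t incoming = is_long
      ? static_cast<int32_t>(*static_cast<const int64_t*>(scalar_data))
      : *static_cast<const int32_t*>(scalar_data);
  if (incoming == cached_val) {
    return false;
  }
  cached_val = incoming;
  return true;
}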
@@ -487,7 +524,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {

     Error err = compileModel(processed->data(), compute_graph);

-    // This backend does not need its processed data after compiling the model.
+    // This backend does not need its processed data after compiling the
+    // model.
     processed->Free();

     if (err != Error::Ok) {
@@ -508,13 +546,31 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
     const size_t num_inputs = compute_graph->inputs().size();
     bool should_propagate_resize = false;
     for (size_t i = 0; i < num_inputs; i++) {
-      bool was_resized =
-          maybe_resize_input(compute_graph, i, args[i]->toTensor());
-      should_propagate_resize = should_propagate_resize || was_resized;
-      compute_graph->copy_into_staging(
-          compute_graph->inputs()[i].staging,
-          args[i]->toTensor().const_data_ptr(),
-          args[i]->toTensor().numel());
+      const ValueRef iref = compute_graph->inputs()[i].value;
+      if (compute_graph->val_is_tensor(iref)) {
+        VK_CHECK_COND(args[i]->isTensor());
+        bool was_resized =
+            maybe_resize_input(compute_graph, i, args[i]->toTensor());
+        should_propagate_resize = should_propagate_resize || was_resized;
+        compute_graph->copy_into_staging(
+            compute_graph->inputs()[i].staging,
+            args[i]->toTensor().const_data_ptr(),
+            args[i]->toTensor().numel());
+      } else if (compute_graph->val_is_symint(iref)) {
+        VK_CHECK_COND(
+            args[i]->isTensor(),
+            "Cannot handle symint arg to graph that is not derived from a "
+            "scalar tensor at the moment.");
+        bool was_updated = maybe_update_scalar_tensor(
+            compute_graph, iref, args[i]->toTensor());
+        // Since symint inputs may impact tensor sizes, trigger a resize if
+        // any symbolic integer shapes are updated.
+        should_propagate_resize = should_propagate_resize || was_updated;
+      } else {
+        VK_THROW(
+            "Could not handle input with type ",
+            compute_graph->get_val_type(iref));
+      }
     }

     if (should_propagate_resize) {
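The execute-time input loop above uses a simple flag-accumulation pattern: each input (a tensor resize or a symint update) only reports whether it changed anything, and propagate_resize() runs at most once per call. A self-contained illustration of that pattern in plain C++ (not ExecuTorch code):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> cached = {1, 2, 3};    // stands in for the graph's current symint/shape state
  std::vector<int> incoming = {1, 5, 3};  // stands in for this call's input values

  bool should_propagate = false;
  for (size_t i = 0; i < cached.size(); i++) {
    const bool changed = (cached[i] != incoming[i]);
    if (changed) {
      cached[i] = incoming[i];
    }
    // Accumulate instead of propagating immediately, so the expensive
    // shape propagation happens at most once per execution.
    should_propagate = should_propagate || changed;
  }

  if (should_propagate) {
    std::printf("would call propagate_resize() once here\n");
  }
  return 0;
}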
@@ -523,13 +579,21 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
     compute_graph->execute();

     for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
-      maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
-      // args holds inputs directly followed by outputs, so the i'th output
-      // for compute_graph corresponds to the (i + num_inputs)'th arg
-      compute_graph->copy_from_staging(
-          compute_graph->outputs()[i].staging,
-          args[num_inputs + i]->toTensor().mutable_data_ptr(),
-          args[num_inputs + i]->toTensor().numel());
+      const ValueRef oref = compute_graph->outputs()[i].value;
+      if (compute_graph->val_is_tensor(oref)) {
+        VK_CHECK_COND(args[i]->isTensor());
+        maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
+        // args holds inputs directly followed by outputs, so the i'th output
+        // for compute_graph corresponds to the (i + num_inputs)'th arg
+        compute_graph->copy_from_staging(
+            compute_graph->outputs()[i].staging,
+            args[num_inputs + i]->toTensor().mutable_data_ptr(),
+            args[num_inputs + i]->toTensor().numel());
+      } else {
+        VK_THROW(
+            "Could not handle output with type ",
+            compute_graph->get_val_type(oref));
+      }
     }

 #ifdef ET_EVENT_TRACER_ENABLED