Skip to content

Commit 55b3621

Browse files
committed
Update base for Update on "[ET-VK] Add reshape functions for transformers related operators"
## Changes * Implement resize functions for several operators used in Transformers models ## Motivation Be able to support batched prefill for llama models. Differential Revision: [D75686049](https://our.internmc.facebook.com/intern/diff/D75686049/) [ghstack-poisoned]
1 parent a54e556 commit 55b3621

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,9 +1660,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
16601660
for (auto& new_sizes : new_sizes_list) {
16611661
graph.get_tensor(a.value)->virtual_resize(new_sizes);
16621662
graph.get_tensor(b.value)->virtual_resize(new_sizes);
1663-
graph.get_tensor(c)->virtual_resize(new_sizes);
16641663
graph.get_tensor(d.value)->virtual_resize(new_sizes);
1665-
graph.get_tensor(e)->virtual_resize(new_sizes);
1664+
graph.propagate_resize();
16661665

16671666
float val_a = new_sizes[1] + 4.0f;
16681667
float val_b = new_sizes[2] + 1.5f;

0 commit comments

Comments
 (0)