Skip to content

Commit 3172fca

Browse files
authored
[ET-VK][ez] Fix embedding resize function
Differential Revision: D85895557 Pull Request resolved: #15476
1 parent 0f9cef4 commit 3172fca

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

backends/vulkan/runtime/graph/ops/impl/Embedding.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ void resize_embedding_node(
4242
const std::vector<ValueRef>& resize_args) {
4343
(void)resize_args;
4444
const ValueRef out = args.at(0).refs.at(0);
45-
const ValueRef weight = args.at(1).refs.at(0);
46-
const ValueRef indices = args.at(1).refs.at(1);
45+
const ValueRef indices = args.at(1).refs.at(0);
46+
const ValueRef weight = args.at(1).refs.at(1);
4747

48-
const std::vector<int64_t> weight_sizes = graph->sizes_of(weight);
4948
const std::vector<int64_t> indices_sizes = graph->sizes_of(indices);
49+
const std::vector<int64_t> weight_sizes = graph->sizes_of(weight);
5050

5151
// Output shape is indices.shape + [embedding_dim]
5252
// where embedding_dim is the last dimension of weight

0 commit comments

Comments
 (0)