Skip to content

Commit 801ea2f

Browse files
committed
Update base for Update on "[ET-VK] Add support for binary symint ops"
## Changes * Add an implementation for binary operators which add symbolic integers. ## Motivation Support executing llama models with dynamic shapes. This operator shows up when exporting with dynamic shapes. Differential Revision: [D75238029](https://our.internmc.facebook.com/intern/diff/D75238029/) [ghstack-poisoned]
1 parent 3d041e6 commit 801ea2f

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ void check_args(
2929
const ValueRef in,
3030
const ValueRef permute_dims,
3131
const ValueRef out) {
32+
(void)permute_dims;
3233
VK_CHECK_COND(check_same_packed_dim(graph, in, out));
3334

3435
// This implementation doesn't not requires the input tensor to have the same
@@ -67,7 +68,7 @@ void resize_permute_node(
6768
in_sizes.size() > out_sizes.size() &&
6869
in_sizes.size() == permute_dims.size()) {
6970
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
70-
const int offset = in_sizes.size() - out_sizes.size();
71+
const size_t offset = in_sizes.size() - out_sizes.size();
7172
for (int i = 0; i < out_sizes.size(); i++) {
7273
const int64_t permute_dim = permute_dims.at(i + offset);
7374
new_out_sizes.at(i) = in_sizes.at(permute_dim);
@@ -79,7 +80,7 @@ void resize_permute_node(
7980
in_sizes.size() < out_sizes.size() &&
8081
out_sizes.size() == permute_dims.size()) {
8182
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
82-
const int offset = out_sizes.size() - in_sizes.size();
83+
const size_t offset = out_sizes.size() - in_sizes.size();
8384
for (int i = 0; i < out_sizes.size(); i++) {
8485
int64_t permute_dim = permute_dims.at(i) - offset;
8586
if (permute_dim >= 0) {
@@ -114,7 +115,8 @@ void add_permute_node(
114115
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
115116
seen[permute_dim] = true;
116117

117-
out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
118+
out_dims[(4u - out_ndim) + i] =
119+
utils::safe_downcast<int32_t>(permute_dim + (4 - out_ndim));
118120
}
119121
}
120122

0 commit comments

Comments
 (0)