Skip to content

Commit ff3822c

Browse files
committed
[ET-VK] Introduce generalized shaders for transfer ops and use it for select and slice
## Changes

* Introduce `transfer_buffer.glsl`, `transfer_texture.glsl`, and `Transfer.cpp`, which generalize shaders where each element of the output is copied from a unique element of the input.
* Update `Slice.cpp` and `Select.cpp` to use `Transfer.cpp`.
* Remove the old implementations of slice and select.

## Motivation

With this new implementation, these ops can now support both buffers and textures of any packing. There are also benefits from code consolidation.

Differential Revision: [D75686050](https://our.internmc.facebook.com/intern/diff/D75686050/)
ghstack-source-id: 287193345
Pull Request resolved: #11255
1 parent 60167da commit ff3822c

32 files changed

+812
-873
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,14 +492,24 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
492492
const ValueRef idx) {
493493
if (values_.at(idx).isInt()) {
494494
const int32_t val = extract_scalar<int32_t>(idx);
495-
create_params_buffer(val);
495+
return create_params_buffer(val);
496496
} else if (values_.at(idx).isSymInt()) {
497497
SymIntPtr symint = get_symint(idx);
498498
return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
499499
}
500500
VK_THROW("Cannot create a int param buffer for the given value");
501501
}
502502

503+
// Overload of get_or_create_int_param_buffer that tolerates a None value.
//
// If the value stored at `idx` is None, a params buffer holding `default_val`
// is created and its bind info returned. For any other value the call is
// forwarded unchanged to the single-argument overload, so its behavior
// (including its error handling) applies.
vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
504+
const ValueRef idx,
505+
const int32_t default_val) {
506+
// None means the value was not provided: fall back to the caller's default.
if (values_.at(idx).isNone()) {
507+
return create_params_buffer(default_val);
508+
} else {
509+
// Anything else is handled by the single-argument overload.
return get_or_create_int_param_buffer(idx);
510+
}
511+
}
512+
503513
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
504514
get_symint(idx)->set(val);
505515
}
@@ -692,6 +702,12 @@ void ComputeGraph::resize_input(
692702
get_tensor(io_val.value)->virtual_resize(new_sizes);
693703
}
694704

705+
// Applies a virtual resize to the tensor referenced by `idx` by forwarding
// `new_sizes` to that tensor's own virtual_resize().
// NOTE(review): nothing here checks that `idx` actually refers to a tensor;
// presumably get_tensor() enforces that — confirm.
void ComputeGraph::virtual_resize(
706+
const ValueRef idx,
707+
const std::vector<int64_t>& new_sizes) {
708+
get_tensor(idx)->virtual_resize(new_sizes);
709+
}
710+
695711
void ComputeGraph::propagate_resize() {
696712
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
697713
node->trigger_resize(this);

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,19 @@ class ComputeGraph final {
397397
std::optional<T> extract_optional_scalar(const ValueRef idx) {
398398
if (val_is_none(idx)) {
399399
return ::std::nullopt;
400+
} else if (val_is_symint(idx)) {
401+
return utils::safe_downcast<T>(read_symint(idx));
402+
} else {
403+
return extract_scalar<T>(idx);
404+
}
405+
}
406+
407+
template <typename T>
408+
T extract_optional_scalar(const ValueRef idx, const T default_val) {
409+
if (val_is_none(idx)) {
410+
return default_val;
411+
} else if (val_is_symint(idx)) {
412+
return utils::safe_downcast<T>(read_symint(idx));
400413
} else {
401414
return extract_scalar<T>(idx);
402415
}
@@ -608,6 +621,10 @@ class ComputeGraph final {
608621
*/
609622
vkapi::BufferBindInfo get_or_create_int_param_buffer(const ValueRef idx);
610623

624+
/*
 * Variant of get_or_create_int_param_buffer that additionally accepts a None
 * value: when the value at `idx` is None, a params buffer holding
 * `default_value` is created and returned; otherwise the call behaves like
 * the single-argument overload.
 */
vkapi::BufferBindInfo get_or_create_int_param_buffer(
625+
const ValueRef idx,
626+
const int32_t default_value);
627+
611628
void set_symint(const ValueRef idx, const int32_t val);
612629

613630
int32_t read_symint(const ValueRef idx);
@@ -752,6 +769,9 @@ class ComputeGraph final {
752769
//
753770

754771
void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
772+
// Virtually resizes the tensor referenced by `idx` to `new_sizes`.
// Unlike resize_input(), which takes an int64_t index, this takes a ValueRef
// directly.
void virtual_resize(
773+
const ValueRef idx,
774+
const std::vector<int64_t>& new_sizes);
755775
void propagate_resize();
756776

757777
//
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef SELECT_GLSLH
10+
#define SELECT_GLSLH
11+
12+
/*
13+
* Enable the fast path if a texel loaded from the input texture can be used as
14+
* is to store to the output texture. The following conditions must be met:
15+
*
16+
* 1. The input and output textures have the same packed dimension.
17+
* 2. The selected_dim must not be the packed dimension of the input.
18+
* 3. The packed dimension of the input must "map" to the packed dimension of
19+
* the output. This occurs if selected_dim is greater than the packed dimension
20+
* of the input.
21+
*/
22+
// Returns true only when both checks below pass. Reads out_packed_dim,
// in_packed_dim and selected_dim, which are not declared in this header and
// must be provided by the including shader.
bool can_use_fast_path() {
23+
// Input and output must share the same packed dimension.
if (out_packed_dim != in_packed_dim) {
24+
return false;
25+
}
26+
// selected_dim must be strictly greater than the packed dimension; equal
// (selecting along the packed dim) or lower values rule out the fast path.
if (selected_dim <= in_packed_dim) {
27+
return false;
28+
}
29+
return true;
30+
}
31+
32+
/*
33+
* Given an output tensor index, return the corresponding input tensor index for
34+
* the select operator. This is done by "inserting" the select index at the
35+
* selected_dim in the input tensor index.
36+
*
37+
* A simple example is (note all tensor indices are in WHCN order):
38+
* out_tidx = [7, 5, 9]
39+
* selected_dim = 2
40+
* index = 3
41+
* in_tidx = [7, 5, 3, 9]
42+
*
43+
* This function assumes that the following variables are defined in the layout:
44+
* - in_sizes
45+
* - selected_dim
46+
* - index
47+
*/
48+
// Builds the input tensor index for `out_tidx` by placing the select index at
// component `selected_dim` and shifting the output components at and after
// that position up by one. Only out_tidx.x/y/z are read; out_tidx.w is unused.
ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) {
49+
// Default result; returned unchanged if selected_dim is not in [0, 3].
ivec4 in_tidx = ivec4(0);
50+
51+
// A negative index counts back from the end of the selected dimension, so
// offset it by that dimension's size.
int adjusted_index = index;
52+
if (index < 0) {
53+
adjusted_index = index + in_sizes[selected_dim];
54+
}
55+
56+
// Insert adjusted_index at the component matching selected_dim
if (selected_dim == 0) {
58+
// Select from width dimension (component x)
59+
in_tidx = ivec4(adjusted_index, out_tidx.x, out_tidx.y, out_tidx.z);
60+
} else if (selected_dim == 1) {
61+
// Select from height dimension (component y)
62+
in_tidx = ivec4(out_tidx.x, adjusted_index, out_tidx.y, out_tidx.z);
63+
} else if (selected_dim == 2) {
64+
// Select from channel dimension (component z)
65+
in_tidx = ivec4(out_tidx.x, out_tidx.y, adjusted_index, out_tidx.z);
66+
} else if (selected_dim == 3) {
67+
// Select from batch dimension (component w)
68+
in_tidx = ivec4(out_tidx.x, out_tidx.y, out_tidx.z, adjusted_index);
69+
}
70+
71+
return in_tidx;
72+
}
73+
74+
#endif // SELECT_GLSLH

backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl

Lines changed: 0 additions & 52 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl

Lines changed: 0 additions & 50 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.yaml

Lines changed: 0 additions & 10 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl

Lines changed: 0 additions & 65 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl

Lines changed: 0 additions & 62 deletions
This file was deleted.

0 commit comments

Comments
 (0)