Skip to content

Commit ba31c57

Browse files
committed
Update on "[ET-VK] Add coop shader for int8 linear"
Title says it all! ## Changes * Apply co-operative shader for vector * matrix computations. Differential Revision: [D73279548](https://our.internmc.facebook.com/intern/diff/D73279548/) [ghstack-poisoned]
2 parents 898b145 + 3277f6d commit ba31c57

File tree

8 files changed

+51
-11
lines changed

8 files changed

+51
-11
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,11 @@ utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
179179
return utils::kChannelsPacked;
180180
}
181181

182+
// Returns true when the device name reported by this graph's adapter contains
// `substr` as a substring, and false otherwise. `substr` is handed directly to
// std::string::find, so it must be a non-null, NUL-terminated C string.
bool ComputeGraph::device_name_contains(const char* substr) {
183+
return context_->adapter_ptr()->device_name().find(substr) !=
184+
std::string::npos;
185+
}
186+
182187
void ComputeGraph::check_no_active_value_ptrs() {
183188
VK_CHECK_COND(
184189
values_in_use_ == 0,

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,15 @@ class ComputeGraph final {
443443
utils::GPUMemoryLayout suggested_memory_layout(
444444
const std::vector<int64_t>& sizes);
445445

446+
// Returns true when the adapter behind this graph's context reports its
// device type as vkapi::DeviceType::ADRENO.
inline bool device_is_adreno() {
447+
return context_->adapter_ptr()->device_type() == vkapi::DeviceType::ADRENO;
448+
}
449+
// Returns a const reference to the device name string held by the adapter
// behind this graph's context.
// NOTE(review): this goes through the context() accessor while the sibling
// helpers read context_ directly; presumably equivalent -- confirm.
const std::string& device_name() {
450+
return context()->adapter_ptr()->device_name();
451+
}
452+
453+
bool device_name_contains(const char* substr);
454+
446455
//
447456
// Graph Building
448457
//

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ void main() {
6060
$if SCALES_STORAGE == "buffer":
6161
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
6262
$else:
63-
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
63+
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
6464

6565
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
6666
partial_c[gid][wid][i] = VEC4_T(0.0);

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,19 @@ q_8w_linear_coop:
1010
IN_STORAGE: texture3d
1111
OUT_STORAGE: texture3d
1212
WEIGHT_STORAGE: texture2d
13-
SCALES_STORAGE: buffer
13+
SCALES_STORAGE: texture2d
1414
TILE_ROWS: 4
1515
generate_variant_forall:
1616
TILE_ROWS:
1717
- VALUE: 1
1818
SUFFIX: o4x1
1919
shader_variants:
20-
- NAME: q_8w_linear_coop_texture3d_texture3d_texture2d_float
21-
- NAME: q_8w_linear_coop_buffer_buffer_texture2d_float
20+
- NAME: q_8w_linear_coop_texture3d_texture3d_texture2d_texture2d_float
21+
- NAME: q_8w_linear_coop_buffer_buffer_texture2d_texture2d_float
2222
IN_STORAGE: buffer
2323
OUT_STORAGE: buffer
24-
- NAME: q_8w_linear_coop_buffer_buffer_buffer_float
24+
- NAME: q_8w_linear_coop_buffer_buffer_buffer_buffer_float
2525
IN_STORAGE: buffer
2626
OUT_STORAGE: buffer
2727
WEIGHT_STORAGE: buffer
28+
SCALES_STORAGE: buffer

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void main() {
5353
$if SCALES_STORAGE == "buffer":
5454
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
5555
$else:
56-
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
56+
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
5757

5858
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
5959
c[i] = VEC4_T(0.0);

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ q_8w_linear_tiled:
1010
IN_STORAGE: texture3d
1111
OUT_STORAGE: texture3d
1212
WEIGHT_STORAGE: texture2d
13-
SCALES_STORAGE: buffer
13+
SCALES_STORAGE: texture2d
1414
TILE_ROWS: 4
1515
generate_variant_forall:
1616
TILE_ROWS:
@@ -21,11 +21,12 @@ q_8w_linear_tiled:
2121
- VALUE: 6
2222
SUFFIX: o4x6
2323
shader_variants:
24-
- NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_float
25-
- NAME: q_8w_linear_tiled_buffer_buffer_texture2d_float
24+
- NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_texture2d_float
25+
- NAME: q_8w_linear_tiled_buffer_buffer_texture2d_texture2d_float
2626
IN_STORAGE: buffer
2727
OUT_STORAGE: buffer
28-
- NAME: q_8w_linear_tiled_buffer_buffer_buffer_float
28+
- NAME: q_8w_linear_tiled_buffer_buffer_buffer_buffer_float
2929
IN_STORAGE: buffer
3030
OUT_STORAGE: buffer
3131
WEIGHT_STORAGE: buffer
32+
SCALES_STORAGE: buffer

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,20 @@ void add_q_8w_linear_tiled_node(
162162
ValueRef q_mat2 = prepack_standard_hw_transposed(
163163
graph, q_mat2_data, q_mat2_storage, utils::kWidthPacked);
164164

165+
utils::StorageType scales_storage = utils::kTexture2D;
166+
if (N > max_extent) {
167+
scales_storage = utils::kBuffer;
168+
}
165169
ValueRef scales =
166-
prepack_standard(graph, scales_data, utils::kBuffer, utils::kWidthPacked);
170+
prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked);
167171

168172
std::string kernel_name =
169173
use_coop_algorithm ? "q_8w_linear_coop" : "q_8w_linear_tiled";
170174
kernel_name.reserve(kShaderNameReserve);
171175
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
172176
add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
173177
add_storage_type_suffix(kernel_name, graph.storage_type_of(q_mat2));
178+
add_storage_type_suffix(kernel_name, graph.storage_type_of(scales));
174179
add_dtype_suffix(kernel_name, graph.dtype_of(out));
175180

176181
std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
@@ -179,6 +184,9 @@ void add_q_8w_linear_tiled_node(
179184
if (M % 6 == 0) {
180185
kernel_name += "_o4x6";
181186
out_tile_nrows = 6;
187+
} else if (M % 4 == 0) {
188+
kernel_name += "_o4x4";
189+
out_tile_nrows = 4;
182190
} else if (M % 1 == 0) {
183191
kernel_name += "_o4x1";
184192
out_tile_nrows = 1;
@@ -255,6 +263,13 @@ bool can_use_tiled_impl(
255263
}
256264

257265
// Decides whether the co-operative shader may be used for `mat1`.
// Returns false when the device is an Adreno whose name contains "702";
// otherwise returns true only when the size of `mat1` at dim -2 equals 1.
bool can_use_coop_impl(ComputeGraph& graph, const ValueRef mat1) {
266+
// Do not use coop algorithm for Adreno 702; manual experimentation shows that
267+
// it performs worse than the tiled algorithm.
268+
// TODO(ssjia): Find a more robust heuristic for deciding when the coop
269+
// algorithm should be used, instead of depending on specific device identity.
270+
if (graph.device_is_adreno() && graph.device_name_contains("702")) {
271+
return false;
272+
}
258273
// Check that the computation is vector * matrix, i.e. dim -2 of mat1 is 1.
259274
return (graph.size_at<int>(-2, mat1) == 1);
260275
}

backends/vulkan/runtime/vk_api/Adapter.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,15 @@ class Adapter final {
122122
return physical_device_.timestamp_period;
123123
}
124124

125+
// Device Identity
126+
// Returns a const reference to the device_name field of physical_device_.
inline const std::string& device_name() const {
127+
return physical_device_.device_name;
128+
}
129+
130+
// Returns, by value, the device_type field of physical_device_.
inline vkapi::DeviceType device_type() const {
131+
return physical_device_.device_type;
132+
}
133+
125134
// Queue Management
126135

127136
Queue request_queue();

0 commit comments

Comments
 (0)