
Commit 12544ff

Kush Rastogi authored and facebook-github-bot committed
Width Packing Mat1 input for Quantized Linear
Summary: Width-pack the mat1 input for Quantized Linear, since the ASR model provides a channel-packed matrix and the operator does not yet support channel packing.

Reviewed By: nathanaelsee

Differential Revision: D64065606
1 parent 9a4d6ce commit 12544ff
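The change wraps the existing q_8w_linear dispatch in a re-packing step: when the texture-backed mat1 is not width packed, a view_copy node produces a width-packed copy, the shader runs on the width-packed tensors, and a second view_copy writes the result back into the output's original layout. Below is a minimal, hypothetical C++ sketch of that control flow only; Tensor, PackedDim, make_width_packed_like, view_copy, and dispatch_q8w_linear are stand-ins for illustration, not the real ComputeGraph/ValueRef API shown in the diff further down.

// Hypothetical sketch of the re-packing pattern added by this commit.
// All types and helpers here are stand-ins, not the backends/vulkan API.
#include <cstdio>

enum class PackedDim { Width, Channels };

struct Tensor {
  PackedDim packed_dim;
  bool buffer_storage;
};

// Stand-in for graph.add_tensor_like(t, utils::kWidthPacked): same metadata,
// width-packed layout.
Tensor make_width_packed_like(const Tensor& t) {
  return Tensor{PackedDim::Width, t.buffer_storage};
}

// Stand-in for the aten.view_copy.default node: copies src's data into dst,
// which may use a different packed layout.
void view_copy(const Tensor& src, const Tensor& dst) {
  std::printf("view_copy: %d -> %d\n",
              static_cast<int>(src.packed_dim),
              static_cast<int>(dst.packed_dim));
}

// Stand-in for the q_8w_linear shader dispatch, which assumes width packing
// for texture-backed tensors.
void dispatch_q8w_linear(const Tensor& mat1, const Tensor& out) {
  std::printf("dispatch: mat1 width-packed=%d, out width-packed=%d\n",
              mat1.packed_dim == PackedDim::Width,
              out.packed_dim == PackedDim::Width);
}

void add_q8w_linear_node(const Tensor& mat1, const Tensor& out) {
  Tensor mat1_W_packed = mat1;
  Tensor out_W_packed = out;
  // Texture-backed, non-width-packed input (e.g. the channel-packed matrix
  // from the ASR model): insert a view to width-pack it first.
  if (!out.buffer_storage && mat1.packed_dim != PackedDim::Width) {
    mat1_W_packed = make_width_packed_like(mat1);
    view_copy(mat1, mat1_W_packed);
    // The shader also produces a width-packed output.
    out_W_packed = make_width_packed_like(out);
  }
  dispatch_q8w_linear(mat1_W_packed, out_W_packed);
  // View the width-packed result back into the caller's original layout.
  if (!out.buffer_storage && out.packed_dim != PackedDim::Width) {
    view_copy(out_W_packed, out);
  }
}

int main() {
  Tensor mat1{PackedDim::Channels, /*buffer_storage=*/false};
  Tensor out{PackedDim::Channels, /*buffer_storage=*/false};
  add_q8w_linear_node(mat1, out);
  return 0;
}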

1 file changed: +26 -13 lines


backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 26 additions & 13 deletions
@@ -71,46 +71,59 @@ void add_q_8w_linear_node(
     const ValueRef q_mat2_data,
     const ValueRef scales_data,
     const ValueRef out) {
+  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
+  ValueRef mat1_W_packed = mat1;
+  ValueRef out_W_packed = out;
+  if (!graph.is_buffer_storage(out) && graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+    // Ensure mat1 is width packed
+    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
+    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
+    // Ensure out is packed correctly
+    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
+  }
   ValueRef q_mat2 =
       prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked);
   ValueRef scales =
       prepack_if_tensor_ref(graph, scales_data, utils::kWidthPacked);
 
   std::string kernel_name = "q_8w_linear";
   kernel_name.reserve(kShaderNameReserve);
-  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1));
+  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
   add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));
 
   vkapi::ParamsBindList ubos({});
-  if (graph.is_buffer_storage(out)) {
+  if (graph.is_buffer_storage(out_W_packed)) {
     ubos.append(
-        {graph.sizes_ubo(out),
-         graph.strides_ubo(out),
-         graph.numel_ubo(out),
-         graph.sizes_ubo(mat1),
+        {graph.sizes_ubo(out_W_packed),
+         graph.strides_ubo(out_W_packed),
+         graph.numel_ubo(out_W_packed),
+         graph.sizes_ubo(mat1_W_packed),
          graph.strides_ubo(mat1),
          graph.strides_ubo(q_mat2),
          graph.strides_ubo(scales)});
   } else {
-    ubos.append({graph.logical_limits_ubo(out), graph.sizes_ubo(mat1)});
+    ubos.append({graph.logical_limits_ubo(out_W_packed), graph.sizes_ubo(mat1_W_packed)});
   }
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
-      graph.create_global_wg_size(out),
-      graph.create_local_wg_size(out),
+      graph.create_global_wg_size(out_W_packed),
+      graph.create_local_wg_size(out_W_packed),
       // Inputs and Outputs
-      {{out, vkapi::MemoryAccessType::WRITE},
-       {{mat1, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
+      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
+       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       ubos,
       // Specialization Constants
       {},
       // Resizing Logic
       resize_qlinear_node));
+  if (!graph.is_buffer_storage(out) && graph.packed_dim_of(out) != WHCN::kWidthDim) {
+    viewFn(graph, {out_W_packed, graph.add_none(), out});
+  }
 }
 
 void weight_int8pack_mm(
