
Commit 14b3af5 (merge of 2 parents: 960b99c + 270271b)

Update

[ghstack-poisoned]

File tree: 4 files changed, +88 -29 lines

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl (16 additions, 7 deletions)
@@ -24,11 +24,20 @@ ${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
 ${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
 ${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
 ${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
-${layout_declare_ubo(4, "ivec3", "out_limits")}
-${layout_declare_ubo(5, "ivec4", "in_sizes")}
-${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
-${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
-${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 out_limits;
+  ivec4 in_sizes;
+  ivec2 kernel_size;
+  ivec2 stride;
+  ivec2 padding;
+  ivec2 dilation;
+  ivec2 overlay_region;
+  int in_group_size;
+  int dummy_padding;
+  float out_min;
+  float out_max;
+};

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

@@ -70,7 +79,7 @@ void main() {

   // If the top left position is out of bounds, then this invocation will have
   // no work to do.
-  if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits))) {
+  if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits.xyz))) {
     return;
   }

@@ -144,7 +153,7 @@ void main() {

   for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
     const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
-    if (all(lessThan(ivec3(pos, gpos.z), out_limits))) {
+    if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
       imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
     }
   }
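
Note on this hunk: all of the small per-dispatch parameters move from five separate UBOs into a single push_constant block. Two details are visible in the diff: out_limits widens from ivec3 to ivec4 (hence the new .xyz swizzles), and a dummy_padding int pads the overlay_region/in_group_size pair out to 16 bytes so the host can copy it as one ivec4-sized region. As a rough illustration of what the host must supply, here is a C++ mirror of the block; the struct name and the raw vkCmdPushConstants call are illustrative assumptions, since ExecuTorch actually drives this through its own PushConstantDataInfo machinery (see the Convolution.cpp diff below).

    #include <vulkan/vulkan.h>

    #include <cstdint>

    // Hypothetical host-side mirror of the shader's push_constant Block.
    // Field order and padding must match the GLSL declaration exactly.
    struct Conv2dPWPushConstants {
      int32_t out_limits[4];     // ivec4 out_limits (w unused; shader reads .xyz)
      int32_t in_sizes[4];       // ivec4 in_sizes
      int32_t kernel_size[2];    // ivec2 kernel_size
      int32_t stride[2];         // ivec2 stride
      int32_t padding[2];        // ivec2 padding
      int32_t dilation[2];       // ivec2 dilation
      int32_t overlay_region[2]; // ivec2 overlay_region
      int32_t in_group_size;     // int in_group_size
      int32_t dummy_padding;     // pads the previous pair to an ivec4-sized slot
      float out_min;             // float out_min
      float out_max;             // float out_max
    };

    // 88 bytes total, comfortably under the 128-byte minimum that Vulkan
    // guarantees for VkPhysicalDeviceLimits::maxPushConstantsSize.
    static_assert(sizeof(Conv2dPWPushConstants) == 88, "layout mismatch");

    // Sketch: uploading the block before a compute dispatch (raw Vulkan,
    // not the ExecuTorch API).
    void push_conv2d_pw_params(
        VkCommandBuffer cmd,
        VkPipelineLayout layout,
        const Conv2dPWPushConstants& pc) {
      vkCmdPushConstants(
          cmd, layout, VK_SHADER_STAGE_COMPUTE_BIT, /*offset=*/0, sizeof(pc), &pc);
    }

The usual motivation for this switch is that push constants are recorded directly into the command buffer, so small, frequently-changing parameters no longer require allocating and binding five separate uniform buffers per dispatch.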

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp (62 additions, 21 deletions)
@@ -407,27 +407,68 @@ void add_conv2d_node(
     wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
   }

-  graph.execute_nodes().emplace_back(new DispatchNode(
-      graph,
-      shader,
-      wg_size,
-      graph.create_local_wg_size(wg_size),
-      // Inputs and Outputs
-      {{out, vkapi::MemoryAccessType::WRITE},
-       {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
-      // Shader params buffers
-      {
-          t_out->logical_limits_ubo(),
-          t_in->sizes_ubo(),
-          graph.create_params_buffer(kernel_params),
-          graph.create_params_buffer(extra_params),
-          graph.create_params_buffer(out_params),
-      },
-      // Specialization Constants
-      {},
-      // Resizing Logic
-      resize_conv2d_node,
-      {weight_data, stride, padding, dilation, transposed, output_padding}));
+  if (method == Conv2dMethod::Pointwise) {
+    const utils::ivec4 kernel_param_size_stride = {
+        kernel_params.kernel_size[0],
+        kernel_params.kernel_size[1],
+        kernel_params.stride[0],
+        kernel_params.stride[1]};
+
+    const utils::ivec4 kernel_param_pad_dial = {
+        kernel_params.padding[0],
+        kernel_params.padding[1],
+        kernel_params.dilation[0],
+        kernel_params.dilation[1]};
+
+    graph.execute_nodes().emplace_back(new DispatchNode(
+        graph,
+        shader,
+        wg_size,
+        graph.create_local_wg_size(wg_size),
+        // Inputs and Outputs
+        {{out, vkapi::MemoryAccessType::WRITE},
+         {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
+        // Shader params buffers
+        {},
+        // Specialization Constants
+        {},
+        // Resizing Logic
+        resize_conv2d_node,
+        {weight_data, stride, padding, dilation, transposed, output_padding},
+        {
+            graph.logical_limits_pc_of(out),
+            graph.sizes_pc_of(in),
+            PushConstantDataInfo(
+                &kernel_param_size_stride, sizeof(kernel_param_size_stride)),
+            PushConstantDataInfo(
+                &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)),
+            PushConstantDataInfo(
+                &extra_params, sizeof(extra_params), sizeof(utils::ivec4)),
+            PushConstantDataInfo(&out_params, sizeof(out_params)),
+        }));
+  } else {
+    graph.execute_nodes().emplace_back(new DispatchNode(
+        graph,
+        shader,
+        wg_size,
+        graph.create_local_wg_size(wg_size),
+        // Inputs and Outputs
+        {{out, vkapi::MemoryAccessType::WRITE},
+         {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
+        // Shader params buffers
+        {
+            t_out->logical_limits_ubo(),
+            t_in->sizes_ubo(),
+            graph.create_params_buffer(kernel_params),
+            graph.create_params_buffer(extra_params),
+            graph.create_params_buffer(out_params),
+        },
+        // Specialization Constants
+        {},
+        // Resizing Logic
+        resize_conv2d_node,
+        {weight_data, stride, padding, dilation, transposed, output_padding}));
+  }
 }

 void add_conv1d_node(
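
Note on this hunk: only the Pointwise method takes the new push-constant path; every other conv2d method keeps the original UBO-based dispatch. The four ivec2 kernel parameters are repacked into two ivec4s, and the PushConstantDataInfo entries are listed in the same order as the shader's push_constant Block. extra_params is pushed with an explicit sizeof(utils::ivec4) so its 12 bytes of data occupy a full 16-byte slot, which is what the shader's dummy_padding member accounts for. A self-contained sketch of that packing arithmetic, with stand-in types rather than ExecuTorch's real PushConstantDataInfo, is below.

    #include <cstddef>
    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Minimal stand-ins for the vector types used above (illustrative only;
    // the real definitions live in ExecuTorch's Vulkan utils headers).
    struct ivec2 { int32_t x, y; };
    struct ivec4 { int32_t x, y, z, w; };

    // Mirrors what the PushConstantDataInfo list appears to do: append each
    // parameter's bytes at its (possibly padded) size, in declaration order.
    struct PushConstantPacker {
      std::vector<uint8_t> bytes;
      void append(const void* src, size_t data_size, size_t padded_size) {
        const size_t offset = bytes.size();
        bytes.resize(offset + padded_size, 0); // zero-fill any padding
        std::memcpy(bytes.data() + offset, src, data_size);
      }
    };

    int main() {
      ivec4 out_limits{}, in_sizes{}, size_stride{}, pad_dilation{};
      struct { ivec2 overlay_region; int32_t in_group_size; } extra_params{};
      struct { float out_min, out_max; } out_params{-1.0f, 1.0f};

      PushConstantPacker pc;
      pc.append(&out_limits, sizeof(out_limits), sizeof(out_limits));
      pc.append(&in_sizes, sizeof(in_sizes), sizeof(in_sizes));
      pc.append(&size_stride, sizeof(size_stride), sizeof(size_stride));
      pc.append(&pad_dilation, sizeof(pad_dilation), sizeof(pad_dilation));
      // 12 bytes of data in an ivec4-sized (16-byte) slot, matching the
      // three-argument PushConstantDataInfo call in the diff.
      pc.append(&extra_params, sizeof(extra_params), sizeof(ivec4));
      pc.append(&out_params, sizeof(out_params), sizeof(out_params));

      // 4 * 16 + 16 + 8 = 88 bytes, matching the shader's push_constant Block.
      return pc.bytes.size() == 88 ? 0 : 1;
    }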

codegen/tools/gen_selected_op_variants.py (9 additions, 1 deletion)
@@ -59,6 +59,13 @@
     "20": "Bits4x2",
     "21": "Bits8",
     "22": "Bits16",
+    "23": "Float8_e5m2",
+    "24": "Float8_e4m3fn",
+    "25": "Float8_e5m2fnuz",
+    "26": "Float8_e4m3fnuz",
+    "27": "UInt16",
+    "28": "UInt32",
+    "29": "UInt64",
 }


@@ -84,7 +91,8 @@ def write_selected_op_variants(yaml_file_path: str, output_dir: str) -> None:
         dtype_set = set([x.split(";")[0] for x in tensor_meta])
         dtype_list = sorted([dtype_enum_to_type[x] for x in dtype_set])
         conditions = [
-            "scalar_type == executorch::aten::ScalarType::" + x for x in dtype_list
+            "scalar_type == executorch::aten::ScalarType::" + x
+            for x in dtype_list
         ]
         body_parts.append(
             ops_and_dtypes_template.substitute(
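
Note on this hunk: the new table entries map ScalarType enum values 23 through 29 so selective build can target the float8 variants and the 16/32/64-bit unsigned types; the second change only re-wraps the condition-string comprehension. For a hypothetical operator whose selected dtypes are Float8_e5m2 and UInt32, the emitted guard would contain comparisons of the following shape. This is a sketch: only the comparison form comes from the diff, while the surrounding function, its name, and the assumed header are illustrative.

    // Assumed header for the executorch::aten::ScalarType alias.
    #include <executorch/runtime/core/exec_aten/exec_aten.h>

    // Hypothetical generated guard; the real template output may differ in
    // shape, but each condition matches the string built in the diff.
    bool dtype_is_selected(executorch::aten::ScalarType scalar_type) {
      return scalar_type == executorch::aten::ScalarType::Float8_e5m2 ||
          scalar_type == executorch::aten::ScalarType::UInt32;
    }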

examples/models/llama/TARGETS (1 addition, 0 deletions)
@@ -117,6 +117,7 @@ runtime.python_library(
         # "//executorch/extension/pybindings:aten_lib",
         # "//executorch/extension/pybindings:portable_lib",
         # "//executorch/extension/pybindings:portable_lib_plus_custom",
+        "//executorch/devtools/backend_debug:delegation_info",
         "//executorch/devtools/etrecord:etrecord",
         "//executorch/util:memory_profiler",
         "//executorch/util:python_profiler",
