Skip to content

Commit 227d6bf

Browse files
committed
Update on "Use c10 version of half/bfloat16 in executorch"
Accomplished by importing relevant files from c10 into executorch/runtime/core/portable_type/c10, and then using `using` in the top-level ExecuTorch headers. This approach should keep the ExecuTorch build hermetic for embedded use cases. In the future, we should add a CI job to ensure the c10 files stay identical to the PyTorch ones. Differential Revision: [D66106969](https://our.internmc.facebook.com/intern/diff/D66106969/) [ghstack-poisoned]
2 parents 6bbef4b + 3891c1c commit 227d6bf

File tree

12 files changed

+170
-42
lines changed

12 files changed

+170
-42
lines changed

.lintrunner.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ include_patterns = [
303303
# 'exir/**/*.py',
304304
# 'extension/**/*.py',
305305
'kernels/**/*.py',
306-
# 'profiler/**/*.py',
306+
'profiler/**/*.py',
307307
'runtime/**/*.py',
308308
'scripts/**/*.py',
309309
# 'test/**/*.py',
@@ -314,6 +314,7 @@ exclude_patterns = [
314314
'third-party/**',
315315
'**/third-party/**',
316316
'scripts/check_binary_dependencies.py',
317+
'profiler/test/test_profiler_e2e.py',
317318
]
318319
command = [
319320
'python',

.mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ follow_untyped_imports = True
4040
[mypy-executorch.kernels.*]
4141
follow_untyped_imports = True
4242

43+
[mypy-executorch.profiler.*]
44+
follow_untyped_imports = True
45+
4346
[mypy-executorch.runtime.*]
4447
follow_untyped_imports = True
4548

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3535
* output at a single output location.
3636
*/
3737
void main() {
38-
const ivec3 pos = ivec3(gl_GlobalInvocationID);
38+
const ivec3 pos = ivec3(
39+
gl_GlobalInvocationID.x % out_limits.x,
40+
(gl_GlobalInvocationID.x / out_limits.x) % out_limits.y,
41+
gl_GlobalInvocationID.x / (out_limits.x * out_limits.y));
3942

4043
if (any(greaterThanEqual(pos, out_limits))) {
4144
return;

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#define TILE_SIZE ${TILE_SIZE}
1616

17+
#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
18+
1719
#define op(X, A, B) ${OPERATOR}
1820

1921
#include "indexing_utils.h"
@@ -39,9 +41,20 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3941
* output at a single output location.
4042
*/
4143
void main() {
42-
const u16vec3 pos = u16vec3(gl_GlobalInvocationID);
44+
// y divided up by batch size is used to determine 3d position
45+
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
46+
const uint out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y;
47+
48+
u16vec3 pos = u16vec3(
49+
gl_GlobalInvocationID.x % out_limits.x,
50+
((gl_GlobalInvocationID.x / out_limits.x) % out_limits_y_scaled),
51+
gl_GlobalInvocationID.x / (out_limits.x * out_limits_y_scaled));
4352

44-
if (any(greaterThanEqual(pos, out_limits))) {
53+
// scale pos.y by batch size, because that's the top pixel to be processed
54+
pos.y *= uint16_t(BATCH_SIZE_Y);
55+
56+
// do not process if top pixel does not fit within the output range
57+
if (any(greaterThanEqual(u16vec3(pos.x, pos.y, pos.z), out_limits))) {
4558
return;
4659
}
4760

@@ -54,18 +67,47 @@ void main() {
5467
const u16vec2 start = ipos;
5568
const u16vec2 end = ipos + u16vec2(overlay_region.xy);
5669

57-
VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
70+
// sum outputs
71+
VEC4_T sum[BATCH_SIZE_Y];
72+
73+
sum[0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
74+
for (int i = 1; i < BATCH_SIZE_Y; i++) {
75+
sum[i] = sum[0];
76+
}
77+
78+
// array to store input texels
79+
VEC4_T in_texels[TILE_SIZE];
80+
81+
// array to store kernel data of previous y
82+
VEC4_T prev_kernel_line[TILE_SIZE];
83+
5884
uint16_t kx = uint16_t(0);
59-
for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++) {
85+
for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1); y += uint16_t(dilation.y), i++) {
6086
for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) {
61-
// The weight kernel was rearranged such that every NxN filter is
62-
// flattened to fit in one row. Each filter was then stacked on top of
63-
// each other vertically.
64-
const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
65-
sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0), sum);
66-
kx++;
87+
in_texels[int(j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
88+
}
89+
90+
// from 2nd iteration onwards accumulate dot product in 2nd sum
91+
// based on kernel line data fetched in previous iteration and input texel from this iteration
92+
if (i > uint16_t(0)) {
93+
for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) {
94+
sum[1] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[1]);
95+
}
96+
}
97+
98+
// accumulate dot product in 1st sum only until tile size
99+
if (i < uint16_t(TILE_SIZE)) {
100+
for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++, kx++) {
101+
prev_kernel_line[int(j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0);
102+
sum[0] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[0]);
103+
}
67104
}
68105
}
69106

70-
imageStore(t_out, pos, op(sum, out_min, out_max));
107+
for (int i = 0; i < BATCH_SIZE_Y; i++) {
108+
if (any(greaterThanEqual(u16vec3(pos.x, pos.y + i, pos.z), out_limits))) {
109+
continue;
110+
}
111+
imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max));
112+
}
71113
}

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ conv2d_dw_output_tile:
1010
NDIM: 3
1111
DTYPE: float
1212
TILE_SIZE: 3
13+
BATCH_SIZE_Y: 2
1314
generate_variant_forall:
1415
DTYPE:
1516
- VALUE: half

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,32 +34,42 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3434

3535
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
3636

37+
// shared memory to hold calculated positions; this reduces register usage, thus improving performance.
38+
shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE];
39+
3740
/*
3841
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
3942
* output tile for pointwise convolution is more efficient because the kernel
4043
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
4144
*/
4245
void main() {
43-
const u16vec3 gpos = u16vec3(gl_GlobalInvocationID);
46+
const uvec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
47+
const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
48+
49+
const u16vec3 gpos = u16vec3(
50+
gl_GlobalInvocationID.x % out_limits_scaled.x,
51+
(gl_GlobalInvocationID.x / out_limits_scaled.x) % out_limits_scaled.y,
52+
gl_GlobalInvocationID.x / (out_limits_scaled.x * out_limits_scaled.y));
4453

4554
// Output position for TILE_SIZE = 2
4655
// +--------+--------+
4756
// | pos[0] | pos[1] |
4857
// +--------+--------+
4958
// | pos[2] | pos[3] |
5059
// +--------+--------+
51-
u16vec3 pos[TILE_SIZE * TILE_SIZE];
60+
u16vec2 pos[TILE_SIZE * TILE_SIZE];
5261
for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
5362
for (int x = 0; x < TILE_SIZE; ++x) {
54-
pos[i] = u16vec3(
55-
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y, gpos.z);
63+
pos[i] = u16vec2(
64+
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
65+
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
5666
i++;
5767
}
5868
}
5969

6070
// If the top left position is out of bounds, then this invocation will have
6171
// no work to do.
62-
if (any(greaterThanEqual(pos[0], out_limits))) {
72+
if (any(greaterThanEqual(u16vec3(pos[0], gpos.z), out_limits))) {
6373
return;
6474
}
6575

@@ -68,7 +78,7 @@ void main() {
6878
// the top-left element is in a region added by padding.
6979
u16vec2 ipos[TILE_SIZE * TILE_SIZE];
7080
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
71-
ipos[i] = pos[i].xy * u16vec2(stride) - u16vec2(padding);
81+
ipos[i] = pos[i] * u16vec2(stride) - u16vec2(padding);
7282
}
7383

7484
vec4 sum[TILE_SIZE * TILE_SIZE];
@@ -133,8 +143,9 @@ void main() {
133143
}
134144

135145
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
136-
if (all(lessThan(pos[i], out_limits))) {
137-
imageStore(t_out, pos[i], op(sum[i], out_min, out_max));
146+
const u16vec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
147+
if (all(lessThan(u16vec3(pos, gpos.z), out_limits))) {
148+
imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max));
138149
}
139150
}
140151
}

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,12 @@ utils::uvec3 create_conv2d_global_wg_size(
296296
utils::div_up(image_extents[0u], 2u),
297297
utils::div_up(image_extents[1u], 2u),
298298
image_extents[2u]};
299+
} else if (method == Conv2dMethod::Depthwise) {
300+
const utils::uvec3 image_extents = graph.logical_limits_of(out);
301+
return {
302+
utils::div_up(image_extents[0u], 1u),
303+
utils::div_up(image_extents[1u], 2u),
304+
image_extents[2u]};
299305
} else {
300306
return graph.create_global_wg_size(out);
301307
}
@@ -370,11 +376,17 @@ void add_conv2d_node(
370376
weight_data,
371377
clamp_out);
372378

379+
utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);
380+
381+
if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
382+
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
383+
}
384+
373385
graph.execute_nodes().emplace_back(new DispatchNode(
374386
graph,
375387
shader,
376-
create_conv2d_global_wg_size(graph, method, out),
377-
graph.create_local_wg_size(out),
388+
wg_size,
389+
graph.create_local_wg_size(wg_size),
378390
// Inputs and Outputs
379391
{{out, vkapi::MemoryAccessType::WRITE},
380392
{{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},

docs/TARGETS

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
load("@fbcode_macros//build_defs:native_rules.bzl", "buck_filegroup", "buck_sh_test")
2+
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
3+
4+
oncall("pytorch_r2p")
5+
6+
python_binary(
7+
name = "sphinx",
8+
main_module = "sphinx.cmd.build",
9+
par_style = "xar",
10+
deps = [
11+
"//caffe2:torch",
12+
"//executorch/exir:lib",
13+
"//executorch/devtools:lib",
14+
"//executorch/exir/backend/test:backend_with_compiler_demo",
15+
"//executorch/exir/backend/test:op_partitioner_demo",
16+
"//executorch/devtools/bundled_program/serialize:lib",
17+
"fbsource//third-party/pypi/ipykernel:ipykernel",
18+
"fbsource//third-party/pypi/jupyter-client:jupyter-client",
19+
"fbsource//third-party/pypi/jupytext:jupytext",
20+
"fbsource//third-party/pypi/nbsphinx:nbsphinx",
21+
"fbsource//third-party/pypi/pytorch-sphinx-theme:pytorch-sphinx-theme",
22+
"fbsource//third-party/pypi/sphinx:sphinx",
23+
"fbsource//third-party/pypi/breathe:breathe",
24+
"fbsource//third-party/pypi/sphinx-copybutton:sphinx-copybutton",
25+
"fbsource//third-party/pypi/sphinx-design:sphinx-design",
26+
"fbsource//third-party/pypi/sphinx-gallery:sphinx-gallery",
27+
"fbsource//third-party/pypi/matplotlib:matplotlib",
28+
"fbsource//third-party/pypi/myst-parser:myst-parser",
29+
],
30+
)
31+
32+
buck_filegroup(
33+
name = "source",
34+
srcs = glob(["source/**/*"]),
35+
)
36+
37+
buck_sh_test(
38+
name = "doctest",
39+
args = [
40+
"-M",
41+
"doctest",
42+
"$(location :source)/source/",
43+
"/tmp/sphinxbuild",
44+
],
45+
test = ":sphinx",
46+
)

docs/source/conf.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
import os
4545
import sys
4646

47+
FBCODE = "fbcode" in os.getcwd()
48+
4749
extensions = [
4850
"breathe",
4951
"sphinx.ext.autodoc",
@@ -60,9 +62,13 @@
6062
"myst_parser",
6163
"sphinx_design",
6264
"sphinx_gallery.gen_gallery",
63-
"executorch_custom_versions",
6465
]
6566

67+
if not FBCODE:
68+
extensions += [
69+
"executorch_custom_versions",
70+
]
71+
6672
this_file_dir = os.path.abspath(os.path.dirname(__file__))
6773
doxygen_xml_dir = os.path.join(
6874
os.path.dirname(this_file_dir), # {repo_root}/docs/

examples/qualcomm/oss_scripts/llama3_2/TARGETS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ python_binary(
99
name = "llama",
1010
srcs = ["llama.py"],
1111
main_function = "executorch.examples.qualcomm.oss_scripts.llama3_2.llama.main",
12+
preload_deps = [
13+
"//executorch/extension/llm/custom_ops:model_sharding_py",
14+
],
1215
deps = [
1316
"//executorch/examples/qualcomm/oss_scripts/llama2:static_llama",
1417
"//caffe2:torch",

0 commit comments

Comments
 (0)