
Commit ea3b53d

Update base for Update on "[ET-VK] Using push constants for conv2d dw."
This diff uses push constants for depthwise (dw) convolution in ExecuTorch's Vulkan backend. This optimization improves memory usage. Differential Revision: [D68493849](https://our.internmc.facebook.com/intern/diff/D68493849/) [ghstack-poisoned]
2 parents 83ae1b9 + 7bc06d1 commit ea3b53d


25 files changed: +386 -93 lines changed


backends/arm/operator_support/to_copy_support.py

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
         # Check dim_order (to_dim_order_copy)
         if "dim_order" in node.kwargs:
             dim_order = node.kwargs["dim_order"]
+            # pyre-ignore[6]
             if dim_order != list(range(len(dim_order))):
                 logger.info(
                     f"Argument {dim_order=} is not supported for "

backends/cadence/aot/compiler.py

Lines changed: 5 additions & 1 deletion
@@ -33,6 +33,7 @@
     ExecutorchProgramManager,
     to_edge,
 )
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
@@ -186,14 +187,17 @@ def export_to_edge(
     edge_prog_manager = to_edge(
         expo_program,
         compile_config=EdgeCompileConfig(
-            _skip_dim_order=True,
             # Allow specific non-core aten ops in the IR.
             _core_aten_ops_exception_list=[
                 torch.ops.aten._native_batch_norm_legit_functional.default,
                 torch.ops.aten.linear.default,
                 torch.ops.aten.linalg_vector_norm.default,
                 torch.ops.aten.unfold.default,
                 torch.ops.aten.angle.default,
+                # Cadence replaced to_dim_order_copy with _to_copy for performance,
+                # so skip the _to_copy op to get around the dim order check.
+                # We should remove this op once Cadence can support dim order.
+                exir_ops.edge.aten._to_copy.default,
             ],
         ),
         constant_methods=constant_methods,

backends/cadence/aot/replace_ops.py

Lines changed: 73 additions & 0 deletions
@@ -11,6 +11,7 @@

 # pyre-unsafe

+import copy
 import math
 from operator import neg
 from typing import cast, Dict, Iterable, Sequence, Set, Tuple
@@ -35,7 +36,12 @@
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
+from executorch.exir.dim_order_utils import get_memory_format
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
+from executorch.exir.passes.dim_order_ops_registry import (
+    DimOrderOpsMap,
+    MemoryFormatOpsMap,
+)
 from torch._subclasses import FakeTensor
 from torch.fx.node import Argument

@@ -1799,6 +1805,72 @@ def call_operator(
         )


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceToDimOrderCopyWithToCopyPass(ExportPass):
+    """
+    dim_order_ops::to_dim_order_copy is not supported, so this is an opt_level=0 pass.
+    If the dim order is sequential, we don't need the extra work with strides and
+    can just use to_copy.
+    """
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in DimOrderOpsMap:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # new kwargs with dim_order, and no memory_format for the new op
+        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
+
+        ndim = None
+
+        # can always get the shape, assuming rank is specialized
+
+        # pyre-ignore[16]: `None` has no attribute `to_tensor`
+        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
+            # pyre-ignore[16]: `None` has no attribute `to_tensor`
+            ndim = args[0].to_tensor().dim()
+        elif isinstance(args[0], torch.Tensor):
+            # pyre-ignore[16]: `None` has no attribute `dim`
+            ndim = args[0].dim()
+        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
+            # pyre-ignore[6]: Incompatible parameter type
+            ndim = len(args[0])
+        else:
+            assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"
+
+        # get the "to" memory format for the EdgeOp
+        contiguous_dim_order = list(range(ndim))
+        dim_order = nkwargs.pop("dim_order", None)
+
+        # Cadence only supports contiguous memory format
+        assert (
+            dim_order is None
+            # pyre-ignore[6]: Incompatible parameter type
+            or len(dim_order) == 0
+            or dim_order == contiguous_dim_order
+        ), "Expected dim order in contiguous or preserve memory format, but got {}".format(
+            dim_order
+        )
+
+        # bring back memory format
+        # pyre-ignore[6]: Incompatible parameter type
+        nkwargs["memory_format"] = get_memory_format(dim_order)
+
+        memory_format_op = MemoryFormatOpsMap[op]
+
+        return super().call_operator(
+            memory_format_op,
+            args,
+            nkwargs,
+            meta,
+        )
+
+
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceFullLikeWithFullPass(ExportPass):
     """
@@ -2108,4 +2180,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
+        ReplaceToDimOrderCopyWithToCopyPass,
     ]
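
The new pass boils down to mapping a sequential dim_order onto a plain memory_format kwarg before falling back to the regular copy op. The sketch below is a self-contained illustration of that mapping; sketch_get_memory_format is a hypothetical stand-in for executorch.exir.dim_order_utils.get_memory_format, whose exact behavior is not shown in this diff.

# Minimal sketch of the dim_order -> memory_format check the pass relies on.
# This mirrors the intent of get_memory_format as used above; the real helper
# may differ in details (assumption, not the library implementation).
from typing import List, Optional

import torch


def sketch_get_memory_format(dim_order: Optional[List[int]]) -> torch.memory_format:
    if dim_order is None or len(dim_order) == 0:
        # No explicit dim order: keep whatever layout the input already uses.
        return torch.preserve_format
    if dim_order == list(range(len(dim_order))):
        # Sequential dim order (e.g. [0, 1, 2, 3]) is plain contiguous layout,
        # so a regular _to_copy with contiguous_format is equivalent.
        return torch.contiguous_format
    raise AssertionError(f"Cadence only supports contiguous dim order, got {dim_order}")


# A 4-D tensor with sequential dim order maps to contiguous_format.
assert sketch_get_memory_format([0, 1, 2, 3]) == torch.contiguous_format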

backends/cadence/fusion_g3/operators/op_exp.cpp

Lines changed: 4 additions & 4 deletions
@@ -49,9 +49,9 @@ Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
       out);
 #endif

-  if (out.scalar_type() == ScalarType::Float) {
-    float* const out_data = out.mutable_data_ptr<float>();
-    const float* const in_data = in.const_data_ptr<float>();
+  if (in.scalar_type() == ScalarType::Float) {
+    float* __restrict__ out_data = out.mutable_data_ptr<float>();
+    const float* __restrict__ in_data = in.const_data_ptr<float>();

     XT_KERNEL_CHECK(
         ctx, out, xa_nn_elm_exp_f32_f32, out_data, in_data, out.numel());
@@ -66,4 +66,4 @@ Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
 } // namespace native
 } // namespace G3
 } // namespace impl
-} // namespace cadence
+} // namespace cadence

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ void main() {
     div_by_x % out_limits.y,
     div_by_x / out_limits.y);

-  if (any(greaterThanEqual(pos, out_limits))) {
+  if (pos.z >= out_limits.z) {
     return;
   }
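
The relaxed bounds check leans on how this shader derives pos: the y and z components shown above are a modulo and a quotient of the same linear index, and x (not shown in the hunk) is presumably a modulo against out_limits.x, so x and y stay in range by construction and only z can overshoot when the dispatch is rounded up. A minimal Python sketch of that indexing under those assumptions (pos_from_linear_id is illustrative, not shader code):

# Sketch of the dw shader's linear-id -> (x, y, z) mapping; names are illustrative.
def pos_from_linear_id(linear_id: int, limits_x: int, limits_y: int) -> tuple[int, int, int]:
    x = linear_id % limits_x           # always < limits_x
    div_by_x = linear_id // limits_x
    y = div_by_x % limits_y            # always < limits_y
    z = div_by_x // limits_y           # can exceed limits_z for over-dispatched ids
    return x, y, z


limits = (5, 4, 3)  # stand-in for out_limits.xyz
for linear_id in range(5 * 4 * 3 + 8):  # dispatch rounded up past the exact size
    x, y, z = pos_from_linear_id(linear_id, limits[0], limits[1])
    assert x < limits[0] and y < limits[1]  # x and y never overflow
    # Only z needs the bounds check that the shader now performs.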

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ void main() {
   pos.y *= BATCH_SIZE_Y;

   // do not process if top pixel does not fit within the output range
-  if (any(greaterThanEqual(pos, out_limits))) {
+  if (pos.z >= out_limits.z) {
     return;
   }

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_sned_output_tile.glsl

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ void main() {
     div_by_x % out_limits.y,
     div_by_x / out_limits.y);

-  if (any(greaterThanEqual(pos, out_limits))) {
+  if (pos.z >= out_limits.z) {
     return;
   }

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 25 additions & 23 deletions
@@ -12,7 +12,9 @@

 #define VEC4_T ${texel_type(DTYPE)}

-#define TILE_SIZE ${TILE_SIZE}
+#define TILE_SIZE_X ${TILE_SIZE_X}
+#define TILE_SIZE_Y ${TILE_SIZE_Y}
+#define LOCAL_WG_SIZE 64

 #define op(X, A, B) ${OPERATOR}

@@ -41,19 +43,19 @@ layout(push_constant) uniform restrict Block {

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

-// shared memory to hold calculated positions, this would reduce register usage thus improving performance.
-// 64 is the number of threads in the local wg
-$num_shared = 64 * TILE_SIZE * TILE_SIZE
-shared ivec2 pos_shared[${num_shared}];
+// For performance improvement, reduce register usage by caching positions in shared memory.
+// Offset index by 1 every 16 points to avoid bank access conflict.
+#define offset_pos_index(index) (index + ((index) >> 4))
+shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y)];

 /*
  * Computes a 2D pointwise convolution of an NxN output tile. Calculating an
  * output tile for pointwise convolution is more efficient because the kernel
  * size is only 1x1, making it easier to re-use loaded texels from t_kernel.
  */
 void main() {
-  const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
-  const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
+  const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
+  const uint shared_mem_stride = LOCAL_WG_SIZE;

   const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
   const ivec3 gpos = ivec3(
@@ -67,33 +69,32 @@ void main() {
   // +--------+--------+
   // | pos[2] | pos[3] |
   // +--------+--------+
-  ivec2 pos[TILE_SIZE * TILE_SIZE];
-  for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
-    for (int x = 0; x < TILE_SIZE; ++x) {
-      pos[i] = ivec2(
-          gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
-      pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
+  ivec2 pos[TILE_SIZE_X * TILE_SIZE_Y];
+  for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
+    for (int x = 0; x < TILE_SIZE_X; ++x) {
+      pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
+      pos_shared[offset_pos_index((shared_mem_stride * i) + gl_LocalInvocationIndex)] = ivec3(pos[i], gpos.z);
       i++;
     }
   }

   // If the top left position is out of bounds, then this invocation will have
   // no work to do.
-  if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits.xyz))) {
+  if (gpos.z >= out_limits.z) {
     return;
   }

   // Compute the index of the input texture that needs to be loaded for each
   // output position. Note that negative indices can be produced indicating that
   // the top-left element is in a region added by padding.
-  ivec2 ipos[TILE_SIZE * TILE_SIZE];
-  for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
+  ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y];
+  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
     ipos[i] = pos[i] * stride - padding;
   }

-  vec4 sum[TILE_SIZE * TILE_SIZE];
+  vec4 sum[TILE_SIZE_X * TILE_SIZE_Y];
   sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
-  for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
+  for (int i = 1; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
     sum[i] = sum[0];
   }

@@ -109,7 +110,7 @@ void main() {
     const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));

 #pragma unroll
-    for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
+    for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
       const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
       // For 2x2 tile size algorithm works as follows.
       // To explain the calculations below, the contents of one in_tex and the
@@ -151,10 +152,11 @@ void main() {
     }
   }

-  for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
-    const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
-    if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
-      imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
+  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
+    const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
+    const ivec3 pos = pos_shared[offset_pos_index(index)];
+    if (all(lessThan(pos, out_limits.xyz))) {
+      imageStore(t_out, pos, op(sum[i], out_min, out_max));
     }
   }
 }
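
The new shared-memory layout pads the cached positions: offset_pos_index(i) = i + (i >> 4) skips one slot after every 16 logical entries, which the shader comment attributes to avoiding shared-memory bank conflicts. A quick Python sketch of the arithmetic, using the LOCAL_WG_SIZE of 64 defined above and the 2x2 tile defaults from the yaml below:

# Padded index used by conv2d_pw.glsl: skip one slot after every 16 entries.
def offset_pos_index(index: int) -> int:
    return index + (index >> 4)


LOCAL_WG_SIZE = 64
TILE_SIZE_X = TILE_SIZE_Y = 2  # defaults from conv2d_pw.yaml

logical_entries = LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y  # 256 cached positions
padded_size = offset_pos_index(logical_entries)              # 256 + 16 = 272 slots

# The mapping is strictly increasing, so no two logical entries share a slot.
padded = [offset_pos_index(i) for i in range(logical_entries)]
assert len(set(padded)) == logical_entries
assert padded_size == 272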

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.yaml

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,8 @@ conv2d_pw:
     OPERATOR: X
     NDIM: 3
     DTYPE: float
-    TILE_SIZE: 2
+    TILE_SIZE_X: 2
+    TILE_SIZE_Y: 2
   generate_variant_forall:
     DTYPE:
       - VALUE: half

backends/xnnpack/test/ops/test_cat.py

Lines changed: 9 additions & 0 deletions
@@ -187,6 +187,15 @@ def test_qs8_cat_gt_5(self):
            inputs.append(torch.randn(1, 2, 3))
        self._test_cat(self.Cat(), tuple(inputs), cat_num=num_inputs, quant=True)

+    def test_qs8_cat_with_empty_tensor(self):
+        inputs = (
+            torch.randn(0, 2, 3),
+            torch.randn(1, 2, 3),
+            torch.randn(3, 2, 3),
+            torch.randn(0, 2, 3),
+        )
+        self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
+
     class CatNegativeDim(torch.nn.Module):
         def __init__(self):
             super().__init__()
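
The new test exercises concatenation where some inputs are empty along the concat dimension. Assuming the test's Cat module concatenates along dim 0 (the torch.cat default), the eager-mode shapes work out as in this small sketch:

import torch

inputs = (
    torch.randn(0, 2, 3),
    torch.randn(1, 2, 3),
    torch.randn(3, 2, 3),
    torch.randn(0, 2, 3),
)

# Empty tensors contribute nothing along the concat dimension,
# so the result stacks to 0 + 1 + 3 + 0 = 4 rows.
out = torch.cat(inputs, dim=0)
assert out.shape == (4, 2, 3)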
