Commit 5dac4a5

Author: Github Executorch (committed)

Update on "Reuse GELU implementation from PyTorch core"

kernels/optimized doesn't need to support embedded systems, so it can take a header-only dependency on PyTorch. Note that, because we pick up Sleef internally and ignore it externally thanks to ATen vec, this PR also enables the optimized GELU in OSS.

Testing: CI, to make sure this doesn't break any mobile build modes; advice on anything not currently covered that might break is welcome.

Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/)

[ghstack-poisoned]
2 parents e44c69c + 9911992 commit 5dac4a5
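
For reference, the GELU being reused computes the standard formulas sketched below. This is an editorial Python sketch of the math only, not the ATen/Vectorized implementation that kernels/optimized now depends on; the helper names are illustrative.

import math
import torch

def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    # Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # "tanh" approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

x = torch.randn(8)
assert torch.allclose(gelu_exact(x), torch.nn.functional.gelu(x), atol=1e-6)
assert torch.allclose(gelu_tanh(x), torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6)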

File tree

22 files changed: +292 -115 lines


backends/cadence/aot/compiler.py

Lines changed: 2 additions & 0 deletions
@@ -264,6 +264,7 @@ def export_to_executorch_gen_etrecord(
     alloc_graph_output: bool = True,
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
+    mem_alignment: int = 1,
 ) -> ExecutorchProgramManager:
     cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)

@@ -290,6 +291,7 @@ def export_to_executorch_gen_etrecord(
         mem_algo=mem_algo,
         alloc_graph_input=alloc_graph_input,
         alloc_graph_output=alloc_graph_output,
+        mem_alignment=mem_alignment,
     )

     # Get executorch program after Cadence specific passes
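
For context, a minimal sketch of how the new mem_alignment argument is passed through, mirroring the new test added below in test_memory_passes.py. Defaults for the other arguments are not shown in this diff, so they are spelled out here and the values are illustrative.

import torch
from executorch.backends.cadence.aot import compiler

class Add(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

# Ask the memory planner to round every buffer's start offset up to 16 bytes
# (mem_alignment defaults to 1, i.e. no extra alignment).
prog = compiler.export_to_executorch_gen_etrecord(
    Add(),
    (torch.randn(4, 17), torch.randn(4, 17)),
    opt_level=1,
    mem_algo=0,
    alloc_graph_input=False,
    alloc_graph_output=False,
    mem_alignment=16,
)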

backends/cadence/aot/memory_planning.py

Lines changed: 23 additions & 8 deletions
@@ -9,6 +9,7 @@
 import collections
 import itertools
 import logging
+import math
 import typing
 from functools import partial
 from typing import Iterable, List, Optional, Tuple

@@ -39,6 +40,10 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
     return memory_config.memory_sizes[exir_id - 1]


+def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
+    return int(math.ceil(pre_aligned_offset / alignment) * alignment)
+
+
 def collect_specs_from_graph_module(
     graph_module: torch.fx.GraphModule,
     alloc_graph_input: bool,

@@ -95,9 +100,9 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
         return None

     def memory_available(spec: TensorSpec) -> bool:
-        return spec.mem_offset + spec.allocated_memory <= get_size(
-            memory_config, spec.mem_id
-        )
+        return get_aligned_offset(
+            spec.mem_offset + spec.allocated_memory, alignment
+        ) <= get_size(memory_config, spec.mem_id)

     # Iterate over all the specs in sorted order
     for spec in sorted(

@@ -116,7 +121,9 @@ def memory_available(spec: TensorSpec) -> bool:
             continue
         spec.mem_offset = 0
         while memory_available(spec) and (overlapped := overlap(spec)):
-            spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
+            spec.mem_offset = get_aligned_offset(
+                overlapped.mem_offset + overlapped.allocated_memory, alignment
+            )
         if memory_available(spec):
             allocated_buffers[spec.mem_id].append(spec)
             bufsizes[spec.mem_id] = max(

@@ -202,13 +209,16 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
             # calculation of gap incorrect. Moving it out will make the algorithm degenerate
             # to the naive one, reusing 0 tensor. The paper may have a typo here.
             prev_offset = max(
-                allocated_spec.mem_offset + allocated_spec.allocated_memory,
+                get_aligned_offset(
+                    allocated_spec.mem_offset + allocated_spec.allocated_memory,
+                    alignment,
+                ),
                 prev_offset,
             )
         if spec.mem_offset is None:
-            if prev_offset + spec.allocated_memory > get_size(
-                memory_config, spec.mem_id
-            ):
+            if get_aligned_offset(
+                prev_offset + spec.allocated_memory, alignment
+            ) > get_size(memory_config, spec.mem_id):
                 continue
             else:
                 spec.mem_offset = prev_offset

@@ -423,6 +433,7 @@ def __init__(
                 ]
             ]
         ] = None,
+        mem_alignment: int = 1,
     ) -> None:
         self._init_mem_algos()

@@ -433,6 +444,9 @@ def __init__(
         self.alloc_graph_output = alloc_graph_output
         self.additional_constraint_gen_passes = additional_constraint_gen_passes

+        assert mem_alignment > 0, "mem_alignment must be positive"
+        self.mem_alignment = mem_alignment
+
     def _init_mem_algos(self) -> None:
         self.available_mem_algos = [
             position_based_greedy_with_hierarchy,

@@ -459,6 +473,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
             allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
+            alignment=self.mem_alignment,
         )
         mem_planning(graph_module)
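
To make the rounding behaviour of the new get_aligned_offset helper concrete, here is a standalone sketch with a few worked values (the input values are illustrative):

import math

def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    # Round the offset up to the next multiple of `alignment`.
    return int(math.ceil(pre_aligned_offset / alignment) * alignment)

assert get_aligned_offset(100, 37) == 111  # rounded up to 3 * 37
assert get_aligned_offset(111, 37) == 111  # already aligned, unchanged
assert get_aligned_offset(0, 37) == 0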

backends/cadence/aot/tests/test_memory_passes.py

Lines changed: 57 additions & 19 deletions
@@ -14,11 +14,12 @@
 from executorch.backends.cadence.aot.pass_utils import count_node
 from executorch.exir import memory
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.memory_planning import collect_specs_from_nodes
 from executorch.exir.tests.models import MultiLayerPerceptron


 class TestMemPlanningPasses(unittest.TestCase):
-    def test_calculate_peak_memory_pass(self):
+    def test_calculate_peak_memory_pass(self) -> None:
         class PeakMemoryTestModel(torch.nn.Module):
             def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
                 super().__init__()

@@ -32,7 +33,7 @@ def forward(self, x: torch.Tensor):
                 x = self.linear2(x)
                 return x

-        def calculate_aligned_num_bytes(num: int, alignment: int = 16):
+        def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
             return math.ceil(num / alignment) * alignment

         # model 1

@@ -86,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16):
         )  # Align data on a 16 byte boundary
         self.assertEqual(peak_usage, expected_peak_usage)

-    def test_zero_memory_pass(self):
+    def test_zero_memory_pass(self) -> None:
         class ZeroMem(torch.nn.Module):
             def forward(self, x):
                 return x[:, 2::3, ...]

@@ -188,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
             f"{spec=} {arg_spec=}",
         )

-    def verify_nop_memory_alloc(self, graph_module):
+    def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
         for node in graph_module.graph.find_nodes(
             op="call_function", target=torch.ops.aten._cat_nop.out
         ):

@@ -204,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module):
         ):
             self._verify_select_nop_memory_alloc(node)

-    def test_optimize_cat_on_placeholders(self):
+    def test_optimize_cat_on_placeholders(self) -> None:
         class Cat(torch.nn.Module):
             def forward(self, x, y):
                 return torch.ops.aten.cat((x, y))

@@ -228,7 +229,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_outermost(self):
+    def test_optimize_cat_outermost(self) -> None:
         class OptimizeCatFeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)

@@ -255,7 +256,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_non_outermost(self):
+    def test_optimize_cat_non_outermost(self) -> None:
         class OptimizeCatFeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)

@@ -282,7 +283,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost(self):
+    def test_no_optimize_cat_non_outermost(self) -> None:
         class OptimizeCatInfeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)

@@ -308,7 +309,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost1(self):
+    def test_no_optimize_cat_non_outermost1(self) -> None:
         class OptimizeCatInfeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)

@@ -335,7 +336,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice(self):
+    def test_optimize_cat_with_slice(self) -> None:
         class OptimizeCatSliceFeasible(torch.nn.Module):
             def forward(self, x):
                 x1 = torch.add(x, 2.4, 3.1)

@@ -364,7 +365,7 @@ def forward(self, x):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice_infeasible(self):
+    def test_optimize_cat_with_slice_infeasible(self) -> None:
         class OptimizeCatSliceInfeasible(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)

@@ -390,7 +391,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_slice_Tensor(self):
+    def test_optimize_slice_Tensor(self) -> None:
         class SliceTensor(torch.nn.Module):
             def forward(self, x, y, z):
                 x1 = torch.add(x, 2.4, 3.1)

@@ -452,7 +453,7 @@ def forward(self, x, y, z):
         )
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_select_Tensor(self):
+    def test_optimize_select_Tensor(self) -> None:
         class SelectTensor(torch.nn.Module):
             def forward(self, x, y, z):
                 x1 = torch.add(x, 2.4, 3.1)

@@ -519,7 +520,7 @@ def forward(self, x, y, z):

     # TODO: Test fails due to memory planning
     @unittest.expectedFailure
-    def test_optimize_cat_with_param(self):
+    def test_optimize_cat_with_param(self) -> None:
         class CatWithPadding(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()

@@ -547,7 +548,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+    def test_optimize_cat_then_slice_on_mutable_buffer(self) -> None:
         class CatWithPadding(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()

@@ -572,7 +573,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_view(self):
+    def test_optimize_cat_with_view(self) -> None:
         class CatViewFeasible(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)

@@ -599,7 +600,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_repeated_args(self):
+    def test_no_optimize_cat_with_repeated_args(self) -> None:
         class CatViewInfeasible(torch.nn.Module):
             def forward(self, x):
                 x1 = torch.add(x, 2.4, 3.1)

@@ -623,7 +624,7 @@ def forward(self, x):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_placeholder(self):
+    def test_no_optimize_cat_with_placeholder(self) -> None:
         class CatViewInfeasible(torch.nn.Module):
             def forward(self, x, y):
                 # Repeat will be decomposed into a cat. The cat cannot be optimized

@@ -741,7 +742,7 @@ def forward(self, x) -> torch.Tensor:
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_view_for_unallocated_output(self):
+    def test_view_for_unallocated_output(self) -> None:
         class Model(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()

@@ -764,3 +765,40 @@ def forward(self, x, y):
         )
         self.assertEqual(count_node(graph_module, memory.view), 1)
         self.verify_nop_memory_alloc(graph_module)
+
+    def test_start_alignment_constraints(self) -> None:
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                add_0 = torch.add(x, y)
+                add_1 = torch.add(x, add_0)
+                add_2 = torch.add(add_0, add_1)
+                add_3 = torch.add(add_1, add_2)
+                return add_3
+
+        model = Model()
+        inputs = (torch.randn(4, 17), torch.randn(4, 17))
+        for mem_algo in range(0, 2):
+            graph_module = (
+                compiler.export_to_executorch_gen_etrecord(
+                    model,
+                    inputs,
+                    opt_level=1,
+                    mem_algo=mem_algo,
+                    alloc_graph_input=False,
+                    alloc_graph_output=False,
+                    mem_alignment=37,
+                )
+                .exported_program()
+                .graph_module
+            )
+            # Assert that every planned allocation starts at an offset aligned
+            # to the 37-byte mem_alignment requested above
+            for spec in collect_specs_from_nodes(
+                graph_module.graph.nodes,
+                ignore_graph_input=True,
+                ignore_graph_output=True,
+            ):
+                if spec and spec.mem_offset:
+                    self.assertEqual(spec.mem_offset % 37, 0)

backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.glsl

Lines changed: 17 additions & 6 deletions
@@ -10,6 +10,9 @@

 #define PRECISION ${PRECISION}

+$if HAS_BIAS:
+  #define HAS_BIAS
+
 #define T ${buffer_scalar_type(DTYPE)}

 ${define_required_extensions(DTYPE)}

@@ -19,13 +22,17 @@ layout(std430) buffer;
 ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
 ${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer")}
 ${layout_declare_tensor(B, "r", "t_mat2", DTYPE, "buffer")}
+$if HAS_BIAS:
+  ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")}
 ${layout_declare_ubo(B, "ivec4", "out_sizes")}
 ${layout_declare_ubo(B, "ivec4", "out_strides")}
 ${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
 ${layout_declare_ubo(B, "ivec4", "mat1_strides")}
 ${layout_declare_ubo(B, "ivec4", "mat2_sizes")}
 ${layout_declare_ubo(B, "ivec4", "mat2_strides")}
 ${layout_declare_ubo(B, "int", "out_numel")}
+$if HAS_BIAS:
+  ${layout_declare_ubo(B, "float", "alpha", "float", "beta")}

 #include "indexing_utils.h"

@@ -34,25 +41,25 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 ${layout_declare_spec_const(C, "int", "mat2_is_transposed", "0")}

 void main() {
-  const ivec4 out_bufix = ivec4(
+  const ivec4 out_tidx = ivec4(
     gl_GlobalInvocationID.x,
     gl_GlobalInvocationID.y,
     gl_GlobalInvocationID.z % out_sizes.z,
     gl_GlobalInvocationID.z / out_sizes.z);

-  if (any(greaterThanEqual(out_bufix, out_sizes))) {
+  if (any(greaterThanEqual(out_tidx, out_sizes))) {
     return;
   }

   int mat1_bufi = tidx_to_bufi(
-      ivec4(0, out_bufix.y, out_bufix.z, out_bufix.w), mat1_strides);
+      ivec4(0, out_tidx.y, out_tidx.z, out_tidx.w), mat1_strides);
   int mat2_bufi;
   if (mat2_is_transposed > 0) {
     mat2_bufi = tidx_to_bufi(
-        ivec4(0, out_bufix.x, 0, 0), mat2_strides);
+        ivec4(0, out_tidx.x, 0, 0), mat2_strides);
   } else {
     mat2_bufi = tidx_to_bufi(
-        ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides);
+        ivec4(out_tidx.x, 0, out_tidx.z, out_tidx.w), mat2_strides);
   }

   int mat2_stride;

@@ -70,6 +77,10 @@ void main() {
     mat2_bufi += mat2_stride;
   }

-  const int out_bufi = tidx_to_bufi(out_bufix, out_strides);
+  const int out_bufi = tidx_to_bufi(out_tidx, out_strides);
+#ifdef HAS_BIAS
+  t_out[out_bufi] = T(alpha) * T(sum) + T(beta) * t_bias[out_tidx.x];
+#else
   t_out[out_bufi] = T(sum);
+#endif // HAS_BIAS
 }
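
With HAS_BIAS enabled, the shader computes alpha * (mat1 @ mat2) + beta * bias per output element, with the bias indexed by the output column. A small PyTorch sketch of the expected arithmetic (shapes and scalars are illustrative, not taken from the shader tests):

import torch

mat1 = torch.randn(4, 8)
mat2 = torch.randn(8, 6)
bias = torch.randn(6)  # broadcast along each output row, like t_bias[out_tidx.x]
alpha, beta = 1.5, 0.5

# Reference for the HAS_BIAS branch: out = alpha * (mat1 @ mat2) + beta * bias
reference = alpha * (mat1 @ mat2) + beta * bias
assert torch.allclose(reference, torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha), atol=1e-6)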
