Skip to content

Commit b56eaa5

Browse files
committed
Update base for Update on "[slimtensor] Introduce Device and ScalarType headers for SlimTensor minimal support"
This diff introduces the foundational c10 core headers for SlimTensor, a lightweight tensor implementation used by torchnative, into the CUDA backend runtime; later it will be used by all AOTI-driven backends such as MPS. We add: DeviceType.h (device type enum, CPU only for now); Device.h (Device class representing the compute device location); and ScalarType.h (scalar type enum with an elementSize() helper, Float only for now). These headers are modeled after PyTorch's c10 but simplified for our needs. The enum values are kept aligned with PyTorch's for serialization compatibility. This is the first step in migrating to SlimTensor, which will replace ETensor as the internal tensor representation in the CUDA backend. Future diffs will add Storage, the SlimTensor class, and additional dtypes/devices incrementally. Differential Revision: [D89747061](https://our.internmc.facebook.com/intern/diff/D89747061/) [ghstack-poisoned]
2 parents 757f5dc + 9ba1b5d commit b56eaa5

File tree

145 files changed

+412
-1297
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

145 files changed

+412
-1297
lines changed

backends/aoti/targets.bzl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def define_common_targets():
4949
supports_python_dlopen = True,
5050
# Constructor needed for backend registration.
5151
compiler_flags = ["-Wno-global-constructors"],
52-
visibility = ["@EXECUTORCH_CLIENTS"],
52+
visibility = ["PUBLIC"],
5353
deps = [
5454
"//executorch/runtime/core:core",
5555
"//executorch/runtime/core/exec_aten:lib",
@@ -67,7 +67,7 @@ def define_common_targets():
6767
supports_python_dlopen = True,
6868
# Constructor needed for backend registration.
6969
compiler_flags = ["-Wno-global-constructors"],
70-
visibility = ["@EXECUTORCH_CLIENTS"],
70+
visibility = ["PUBLIC"],
7171
deps = [
7272
"//executorch/runtime/backend:interface",
7373
"//executorch/runtime/core:core",
@@ -80,7 +80,7 @@ def define_common_targets():
8080
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
8181
link_whole = True,
8282
supports_python_dlopen = True,
83-
visibility = ["@EXECUTORCH_CLIENTS"],
83+
visibility = ["PUBLIC"],
8484
exported_deps = [
8585
":common_shims",
8686
":delegate_handle",

backends/apple/coreml/TARGETS

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ oncall("executorch")
88
# TODO: this is a placeholder to support internal fbcode build. We should add the coreml backend target properly.
99
runtime.python_library(
1010
name = "coreml",
11-
visibility = [
12-
"@EXECUTORCH_CLIENTS",
13-
],
11+
visibility = ["PUBLIC"],
1412
)
1513

1614
runtime.python_library(
@@ -19,9 +17,7 @@ runtime.python_library(
1917
"compiler/*.py",
2018
"logging.py",
2119
]),
22-
visibility = [
23-
"@EXECUTORCH_CLIENTS",
24-
],
20+
visibility = ["PUBLIC"],
2521
deps = [
2622
"fbsource//third-party/pypi/coremltools:coremltools",
2723
":executorchcoreml",
@@ -36,9 +32,7 @@ runtime.python_library(
3632
"partition/*.py",
3733
"logging.py",
3834
]),
39-
visibility = [
40-
"@EXECUTORCH_CLIENTS",
41-
],
35+
visibility = ["PUBLIC"],
4236
deps = [
4337
"fbsource//third-party/pypi/coremltools:coremltools",
4438
":backend",
@@ -55,9 +49,7 @@ runtime.python_library(
5549
srcs = glob([
5650
"quantizer/*.py",
5751
]),
58-
visibility = [
59-
"@EXECUTORCH_CLIENTS",
60-
],
52+
visibility = ["PUBLIC"],
6153
)
6254

6355
runtime.python_library(
@@ -66,10 +58,7 @@ runtime.python_library(
6658
"recipes/__init__.py",
6759
"recipes/coreml_recipe_provider.py"
6860
],
69-
visibility = [
70-
"@EXECUTORCH_CLIENTS",
71-
"//executorch/export/...",
72-
],
61+
visibility = ["PUBLIC"],
7362
deps = [
7463
"fbsource//third-party/pypi/coremltools:coremltools",
7564
":coreml_recipe_types",
@@ -91,10 +80,7 @@ runtime.python_library(
9180
srcs = [
9281
"recipes/coreml_recipe_types.py",
9382
],
94-
visibility = [
95-
"@EXECUTORCH_CLIENTS",
96-
"//executorch/export/...",
97-
],
83+
visibility = ["PUBLIC"],
9884
deps = [
9985
"//executorch/export:recipe",
10086
],
@@ -124,10 +110,7 @@ runtime.cxx_python_extension(
124110
types = [
125111
"executorchcoreml.pyi",
126112
],
127-
visibility = [
128-
"//executorch/examples/apple/coreml/...",
129-
"@EXECUTORCH_CLIENTS",
130-
],
113+
visibility = ["PUBLIC"],
131114
deps = [
132115
"fbsource//third-party/nlohmann-json:nlohmann-json",
133116
"fbsource//third-party/pybind11:pybind11",

backends/apple/mps/TARGETS

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ runtime.python_library(
1919
"__init__.py",
2020
"mps_preprocess.py",
2121
],
22-
visibility = [
23-
"@EXECUTORCH_CLIENTS",
24-
],
22+
visibility = ["PUBLIC"],
2523
deps = [
2624
":operators",
2725
":serialization",
@@ -49,9 +47,7 @@ runtime.python_library(
4947
srcs = glob([
5048
"partition/*.py",
5149
]),
52-
visibility = [
53-
"@EXECUTORCH_CLIENTS",
54-
],
50+
visibility = ["PUBLIC"],
5551
deps = [
5652
":backend",
5753
"//caffe2:torch",

backends/apple/mps/targets.bzl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,7 @@ def define_common_targets(is_xplat = False, platforms = []):
3939
"runtime/operations/*.h",
4040
]),
4141
"srcs": MPS_BACKEND_BUCK_SRCS,
42-
"visibility": [
43-
"//executorch/backends/apple/...",
44-
"//executorch/examples/...",
45-
"//executorch/exir/backend:backend_lib",
46-
"//executorch/extension/pybindings/...",
47-
"//executorch/runtime/backend/...",
48-
"//executorch/devtools/runners/...",
49-
"//executorch/test/...",
50-
"@EXECUTORCH_CLIENTS",
51-
],
42+
"visibility": ["PUBLIC"],
5243
"link_whole": True,
5344
}
5445

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def call(self, graph_module: torch.fx.GraphModule):
9090
args = node.args
9191
meta = node.meta
9292
match len(args):
93+
case 6:
94+
# torch.ops.aten.layer_norm.default has 6 args:
95+
# (input, normalized_shape, weight, bias, eps, cudnn_enable)
96+
# cudnn_enable is not used in the decomposition
97+
x, normalized_shape, weights, bias, epsilon, _cudnn_enable = args
9398
case 5:
9499
x, normalized_shape, weights, bias, epsilon = args
95100
case 4:

backends/arm/runtime/targets.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ def define_common_targets():
55
name = "vela_bin_stream",
66
srcs = ["VelaBinStream.cpp"],
77
exported_headers = ["VelaBinStream.h"],
8-
visibility = ["@EXECUTORCH_CLIENTS"],
8+
visibility = ["PUBLIC"],
99
deps = [
1010
"//executorch/runtime/core:core",
1111
],
@@ -21,7 +21,7 @@ def define_common_targets():
2121
supports_python_dlopen = True,
2222
# Constructor needed for backend registration.
2323
compiler_flags = ["-Wno-global-constructors"],
24-
visibility = ["@EXECUTORCH_CLIENTS"],
24+
visibility = ["PUBLIC"],
2525
deps = [
2626
"//executorch/runtime/backend:interface",
2727
":vela_bin_stream",

backends/arm/test/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def define_arm_tests():
1919
"ops/test_avg_pool2d.py",
2020
"ops/test_cat.py",
2121
"ops/test_conv2d.py",
22-
"ops/test_linear.py",
22+
"ops/test_linear.py",
2323
"ops/test_mul.py",
2424
"ops/test_permute.py",
2525
"ops/test_rsqrt.py",

backends/cadence/aot/reorder_ops.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,9 @@ def advancing_feasible(self, quant_node: torch.fx.Node):
299299
# All the conditions satisfied, we advance.
300300
return True
301301

302-
def advance_quantize_op(self, graph_module: torch.fx.GraphModule):
302+
def advance_quantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
303303
graph = graph_module.graph
304+
modified = False
304305
for node in reversed(graph.nodes):
305306
if get_overload_packet(node.target) not in (
306307
exir_ops.edge.quantized_decomposed.quantize_per_tensor,
@@ -339,15 +340,19 @@ def advance_quantize_op(self, graph_module: torch.fx.GraphModule):
339340
# We can safely remove the quant node and trivially quantizable op
340341
graph.erase_node(node)
341342
graph.erase_node(trivially_quantizable_op)
343+
modified = True
342344

343-
graph_module.recompile()
344-
graph_module.graph.eliminate_dead_code()
345+
return modified
345346

346347
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
347348
self.graph_module = graph_module
348-
self.advance_quantize_op(graph_module)
349-
result = super().call(graph_module)
350-
return result
349+
modified = self.advance_quantize_op(graph_module)
350+
if modified:
351+
graph_module.recompile()
352+
graph_module.graph.eliminate_dead_code()
353+
return super().call(graph_module)
354+
355+
return PassResult(graph_module, False)
351356

352357

353358
@register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -474,14 +479,21 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
474479
# the graph (up to 3 times max, to avoid potential infinite loops)
475480
self.graph_module = graph_module
476481
iter_count = 0
477-
modified = True
482+
local_modified = False
483+
overall_modified = False
484+
485+
while local_modified or iter_count == 0:
486+
local_modified = self.postpone_dequantize_op(self.graph_module)
487+
overall_modified |= local_modified
488+
489+
if local_modified:
490+
self.graph_module = super().call(self.graph_module).graph_module
478491

479-
while modified and iter_count < 3:
480-
modified = self.postpone_dequantize_op(self.graph_module)
481-
self.graph_module = super().call(self.graph_module).graph_module
482492
iter_count += 1
493+
if iter_count == 3:
494+
break
483495

484-
return super().call(self.graph_module)
496+
return PassResult(self.graph_module, overall_modified)
485497

486498

487499
@register_cadence_pass(CadencePassAttribute(opt_level=1))

backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -286,13 +286,14 @@ def test_advance_branched_quantize(self) -> None:
286286
@torch.no_grad()
287287
def test_advance_quantize(self) -> None:
288288
builder = GraphBuilder()
289-
x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32))
290-
weights = builder.placeholder(
291-
"weights", torch.randint(-128, 127, (32, 32), dtype=torch.int8)
292-
)
289+
x_data = torch.randn(16, 1, 32, 6, dtype=torch.float32)
290+
weight_data = torch.randint(-128, 127, (32, 32), dtype=torch.int8)
291+
x = builder.placeholder("x", x_data)
292+
weights = builder.placeholder("weights", weight_data)
293293
full = builder.call_operator(
294294
op=exir_ops.edge.aten.full.default,
295295
args=([1], -7),
296+
kwargs={"dtype": torch.int32},
296297
)
297298
full_1 = builder.call_operator(
298299
op=exir_ops.edge.aten.full.default,
@@ -304,7 +305,8 @@ def test_advance_quantize(self) -> None:
304305
)
305306
full_3 = builder.call_operator(
306307
op=exir_ops.edge.aten.full.default,
307-
args=([12], 0.0),
308+
args=([1], 0),
309+
kwargs={"dtype": torch.int32},
308310
)
309311
permute = builder.call_operator(
310312
op=exir_ops.edge.aten.permute_copy.default,
@@ -337,8 +339,13 @@ def test_advance_quantize(self) -> None:
337339

338340
p1 = AdvanceQuantizeOpAboveDefInBranchPass()
339341
tmp_graph = cast(PassResult, p1(original_graph)).graph_module
340-
p2 = AdvanceQuantizeOpAboveDefChainPass()
341-
converted_graph = cast(PassResult, p2(tmp_graph)).graph_module
342+
result = transform_and_check_numerics(
343+
tmp_graph,
344+
(x_data, weight_data),
345+
AdvanceQuantizeOpAboveDefChainPass(),
346+
)
347+
self.assertFalse(result.modified)
348+
converted_graph = result.graph_module
342349
# Assert that permute node is now the successor of the quant node.
343350
self.assertTrue(
344351
get_node_pos(
@@ -349,13 +356,14 @@ def test_advance_quantize(self) -> None:
349356

350357
def test_postpone_dequantize1(self) -> None:
351358
builder = GraphBuilder()
352-
x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32))
353-
weights = builder.placeholder(
354-
"weights", torch.randint(-128, 127, (6, 6), dtype=torch.int8)
355-
)
359+
x_data = torch.randn(1, 16, 32, 6, dtype=torch.float32)
360+
weight_data = torch.randint(-128, 127, (6, 6), dtype=torch.int8)
361+
x = builder.placeholder("x", x_data)
362+
weights = builder.placeholder("weights", weight_data)
356363
full = builder.call_operator(
357364
op=exir_ops.edge.aten.full.default,
358365
args=([1], -7),
366+
kwargs={"dtype": torch.int32},
359367
)
360368
full_1 = builder.call_operator(
361369
op=exir_ops.edge.aten.full.default,
@@ -367,7 +375,8 @@ def test_postpone_dequantize1(self) -> None:
367375
)
368376
full_3 = builder.call_operator(
369377
op=exir_ops.edge.aten.full.default,
370-
args=([12], 0.0),
378+
args=([1], 0),
379+
kwargs={"dtype": torch.int32},
371380
)
372381
quantize_per_tensor = builder.call_operator(
373382
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
@@ -397,8 +406,13 @@ def test_postpone_dequantize1(self) -> None:
397406
)
398407
builder.output([permute])
399408
original_graph = builder.get_graph_module()
400-
p = PostponeDequantizeOpBelowUseChainPass()
401-
converted_graph = cast(PassResult, p(original_graph)).graph_module
409+
result = transform_and_check_numerics(
410+
original_graph,
411+
(x_data, weight_data),
412+
PostponeDequantizeOpBelowUseChainPass(),
413+
)
414+
self.assertTrue(result.modified)
415+
converted_graph = result.graph_module
402416
# Assert that dequant node is now the successor of the permute node.
403417
self.assertTrue(
404418
get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)

backends/cadence/fusion_g3/operators/targets.bzl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,7 @@ def define_operator(name: str, deps: list[str] | None = None) -> None:
2323
name = op_name,
2424
srcs = [op_name + ".cpp"],
2525
platforms = CXX,
26-
visibility = [
27-
"//executorch/backends/cadence/...",
28-
"@EXECUTORCH_CLIENTS",
29-
],
26+
visibility = ["PUBLIC"],
3027
compatible_with = ["ovr_config//cpu:xtensa"],
3128
deps = deps + common_deps,
3229
)

0 commit comments

Comments
 (0)