
Commit 9187fb0

Update on "make et.export support etrecord generation"
This diff adds ETRecord generation support to et.export. Details can be found in #12925; after this change, all work tracked in #12925 is complete. Differential Revision: [D79741917](https://our.internmc.facebook.com/intern/diff/D79741917/) [ghstack-poisoned]
2 parents 6593cc3 + 8003cb0 commit 9187fb0
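
For context, generating an ETRecord has previously required calling the devtools API by hand on both the edge-dialect and ExecuTorch programs; this change folds that step into the et.export flow described in #12925. Below is a minimal sketch of the pre-existing manual flow, assuming the `executorch.devtools.generate_etrecord` API and a hypothetical `MulModule` toy model; it is illustrative only and not part of this diff:

```python
import copy

import torch
from executorch.devtools import generate_etrecord
from executorch.exir import to_edge


class MulModule(torch.nn.Module):  # hypothetical toy model for illustration
    def forward(self, x, y):
        return x * y


example_inputs = (torch.ones(4), torch.ones(4))
exported = torch.export.export(MulModule(), example_inputs)

edge_program = to_edge(exported)
# ETRecord needs the edge-dialect graph captured *before* to_executorch mutates it.
edge_program_copy = copy.deepcopy(edge_program)
et_program = edge_program.to_executorch()

# Writes the ETRecord artifact consumed by the devtools Inspector.
generate_etrecord("model_etrecord.bin", edge_program_copy, et_program)
```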

File tree

14 files changed: +268 -29 lines


backends/apple/coreml/compiler/torch_ops.py

Lines changed: 41 additions & 1 deletion
@@ -8,6 +8,7 @@
 # coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
 # the op to the coremltools library.
 
+import numpy as np
 import torch as _torch
 from coremltools import _logger
 from coremltools.converters.mil.frontend import _utils
@@ -21,7 +22,6 @@
     transpose,
     unbind,
 )
-
 from coremltools.converters.mil.frontend.torch.torch_op_registry import (
     register_torch_op,
 )
@@ -132,3 +132,43 @@ def dequantize_affine(context, node):
         name=node.name,
     )
     context.add(output, node.name)
+
+
+@register_torch_op(
+    torch_alias=["quant::dequantize_codebook", "quant.dequantize_codebook"],
+    override=False,
+)
+def dequantize_codebook(context, node):
+    inputs = _get_inputs(context, node, expected=[4, 5])
+    codes = inputs[0].val
+    codebook = inputs[1].val
+    nbits = inputs[2].val
+
+    # information in block_size is redundant with codebook.shape
+    block_size = inputs[3].val  # noqa: F841
+
+    assert len(codes.shape) == 2, "Only rank 2 inputs are supported"
+
+    # Assert codebook is as expected. codebook.dim() = codes.dim() + 2
+    assert len(codebook.shape) == 4, "Only rank 4 inputs are supported for codebook"
+    assert codebook.shape[0] == 1, "Only grouped_channel granularity is supported"
+    n_luts = codebook.shape[1]
+    assert (
+        codes.shape[1] % n_luts == 0
+    ), "codes.shape[1] must be divisible by codebook.shape[1]"
+    assert codebook.shape[2] == 2**nbits
+    assert codebook.shape[3] == 1, "Only scalar look up values are supported"
+
+    if len(inputs) > 4:
+        output_dtype = inputs[4].val
+        out_np_dtype = NUM_TO_NUMPY_DTYPE[output_dtype]
+        _logger.warning(
+            f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
+        )
+
+    output = _utils._construct_constexpr_lut_op(
+        codes.astype(np.int8),
+        codebook,
+        name=node.name,
+    )
+    context.add(output, node.name)
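
For intuition, the op registered above lowers `quant.dequantize_codebook` to a Core ML constexpr LUT op whose result is the per-group table lookup sketched below in numpy. The contiguous column-to-LUT grouping and the `decode_codebook` helper are assumptions for illustration, inferred from the shape checks above and not part of this commit; coremltools is the authoritative reference.

```python
import numpy as np


def decode_codebook(codes: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Reference decode assumed equivalent to the constexpr LUT op.

    codes:    [rows, cols] integer indices into per-group lookup tables.
    codebook: [1, n_luts, 2**nbits, 1]; each LUT is assumed to cover a
              contiguous block of columns (grouped_channel granularity).
    """
    rows, cols = codes.shape
    n_luts = codebook.shape[1]
    group = cols // n_luts  # columns served by each lookup table (assumed contiguous)
    out = np.empty((rows, cols), dtype=codebook.dtype)
    for j in range(cols):
        lut = codebook[0, j // group, :, 0]  # table for this column's group
        out[:, j] = lut[codes[:, j]]
    return out


# Example: 4 columns, 2 LUTs of 2**2 = 4 entries each (nbits = 2)
codes = np.array([[0, 3, 1, 2], [2, 1, 0, 3]], dtype=np.int8)
codebook = np.arange(8, dtype=np.float32).reshape(1, 2, 4, 1)
print(decode_codebook(codes, codebook))  # [[0. 3. 5. 6.] [2. 1. 4. 7.]]
```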

backends/apple/coreml/runtime/delegate/coreml_backend_delegate.mm

Lines changed: 8 additions & 6 deletions
@@ -88,17 +88,17 @@
         ET_LOG(Error, "%s: DataType=%d is not supported", ETCoreMLStrings.delegateIdentifier.UTF8String, (int)tensor.scalar_type());
         return std::nullopt;
     }
-
+
     std::vector<ssize_t> strides(tensor.strides().begin(), tensor.strides().end());
     std::vector<size_t> shape(tensor.sizes().begin(), tensor.sizes().end());
-
+
     // If tensor is rank 0, wrap in rank 1
     // See https://github.com/apple/coremltools/blob/8.2/coremltools/converters/mil/frontend/torch/exir_utils.py#L73
     if (shape.size() == 0) {
         shape.push_back(1);
         strides.push_back(1);
     }
-
+
     MultiArray::MemoryLayout layout(dataType.value(), std::move(shape), std::move(strides));
     switch (argType) {
         case ArgType::Input: {
@@ -281,9 +281,11 @@ ModelLoggingOptions get_logging_options(BackendExecutionContext& context) {
 }
 
 namespace {
-auto cls = CoreMLBackendDelegate();
-Backend backend{ETCoreMLStrings.delegateIdentifier.UTF8String, &cls};
-static auto success_with_compiler = register_backend(backend);
+#ifndef LAZY_LOAD_IOS_PYTORCH_INITIALIZER
+auto cls = CoreMLBackendDelegate();
+Backend backend{ETCoreMLStrings.delegateIdentifier.UTF8String, &cls};
+static auto success_with_compiler = register_backend(backend);
+#endif
 }
 
 } // namespace coreml

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 60 additions & 0 deletions
@@ -14,6 +14,9 @@
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.exir.backend.utils import format_delegated_graph
+
+from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
 from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_
 
 
@@ -164,6 +167,61 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)
 
+    def test_dequantize_codebook_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]),
+        )
+        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        assert (
+            "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
+            in format_delegated_graph(delegated_program.exported_program().graph_module)
+        )
+
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_codebook_embedding(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint3, block_size=[-1, 16]),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        ep = torch.export.export(model, example_inputs)
+        assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+
+        assert (
+            "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
+            in format_delegated_graph(delegated_program.exported_program().graph_module)
+        )
+
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
 
 if __name__ == "__main__":
     test_runner = TestTorchOps()
@@ -172,3 +230,5 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
     test_runner.test_dequantize_affine_c4w_embedding()
     test_runner.test_dequantize_affine_c4w_linear()
     test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
+    test_runner.test_dequantize_codebook_linear()
+    test_runner.test_dequantize_codebook_embedding()

backends/xnnpack/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ foreach(fbs_file ${_xnnpack_schema__srcs})
   )
 endforeach()
 
-if(WIN32)
+if(WIN32 AND NOT CMAKE_CROSSCOMPILING)
   set(MV_COMMAND
     powershell -Command
     "Move-Item -Path ${_xnnpack_flatbuffer__outputs} -Destination ${_xnnpack_schema__outputs}"

devtools/etrecord/tests/TARGETS

Lines changed: 2 additions & 7 deletions
@@ -7,13 +7,7 @@ python_unittest(
     name = "etrecord_test",
     srcs = ["etrecord_test.py"],
     deps = [
-        "//caffe2:torch",
-        "//executorch/devtools/bundled_program:config",
-        "//executorch/devtools/bundled_program:core",
-        "//executorch/devtools/etrecord:etrecord",
-        "//executorch/exir:lib",
-        "//executorch/exir/tests:models",
-        "//executorch/export:lib",
+        ":etrecord_test_library"
     ],
 )
 
@@ -27,5 +21,6 @@ python_library(
         "//executorch/devtools/etrecord:etrecord",
        "//executorch/exir:lib",
         "//executorch/exir/tests:models",
+        "//executorch/export:lib",
     ],
 )

extension/wasm/CMakeLists.txt

Lines changed: 6 additions & 3 deletions
@@ -37,7 +37,6 @@ list(
   embind
   executorch_core
   extension_data_loader
-  portable_ops_lib
   extension_module_static
   extension_tensor
   extension_runner_util
@@ -49,8 +48,12 @@ target_compile_options(executorch_wasm PUBLIC ${_common_compile_options})
 target_include_directories(
   executorch_wasm PUBLIC ${_common_include_directories}
 )
-target_link_libraries(executorch_wasm PUBLIC ${link_libraries})
+target_link_libraries(
+  executorch_wasm
+  PUBLIC ${link_libraries}
+  INTERFACE executorch_kernels
+)
 
-if(EXECUTORCH_BUILD_WASM_TESTS)
+if(BUILD_TESTING)
   add_subdirectory(test)
 endif()

extension/wasm/README.md

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# ExecuTorch Wasm Extension

This directory contains the source code for the ExecuTorch Wasm extension, a C++ library that provides a JavaScript API for running ExecuTorch models. The extension is compiled to WebAssembly and can be used in JavaScript applications.

## Installing Emscripten

[Emscripten](https://emscripten.org/index.html) is necessary to compile ExecuTorch for Wasm. You can install Emscripten with these commands:

```bash
# Clone the emsdk repository
git clone https://github.com/emscripten-core/emsdk.git
cd emsdk

# Download and install version 4.0.10 of the SDK
./emsdk install 4.0.10
./emsdk activate 4.0.10

# Add the Emscripten environment variables to your shell
source ./emsdk_env.sh
```

## Building ExecuTorch for Wasm

To build ExecuTorch for Wasm, make sure to use the `emcmake cmake` command and to have `EXECUTORCH_BUILD_WASM` enabled. For example:

```bash
# Configure the build with the Emscripten environment variables
emcmake cmake . -DEXECUTORCH_BUILD_WASM=ON \
    -DCMAKE_BUILD_TYPE=Release \
    -Bcmake-out-wasm

# Build the Wasm extension
cmake --build cmake-out-wasm --target executorch_wasm -j32
```

To reduce the binary size, you may also use the selective build options found in the [Kernel Library Selective Build guide](../../docs/source/kernel-library-selective-build.md). You may also use optimized kernels with the `EXECUTORCH_BUILD_KERNELS_OPTIMIZED` option. Portable kernels are used by default.

### Building for Web

In your CMakeLists.txt, add the following lines:

```cmake
add_executable(executorch_wasm_lib) # Emscripten outputs this as a JS and Wasm file
target_link_libraries(executorch_wasm_lib PRIVATE executorch_wasm)
target_link_options(executorch_wasm_lib PRIVATE ...) # Add any additional link options here
```

You can find the Emscripten link options in the [emcc reference](https://emscripten.org/docs/tools_reference/emcc.html).

Building this should output `executorch_wasm_lib.js` and `executorch_wasm_lib.wasm` in the build directory. You can then use these files in your page.

```html
<script>
  // Emscripten calls Module.onRuntimeInitialized once the runtime is ready.
  var Module = {
    onRuntimeInitialized: function() {
      const et = Module; // Assign Module to et for ease of use
      const model = et.Module.load("mv2.pte");
      // ...
    }
  }
</script>
<script src="executorch_wasm_lib.js"></script>
```

### Building for Node.js

While the standard way to import a module in Node.js is the `require` function, loading the library that way does not give you access to the [Emscripten API](https://emscripten.org/docs/api_reference/index.html), which is stored in the globals. For example, you may want to use the [File System API](https://emscripten.org/docs/api_reference/Filesystem-API.html) in your unit tests, which cannot be done if the library is loaded with `require`. Instead, you can use the `--pre-js` option to prepend your file to the start of the JS output so that it behaves similarly to the example in the [Web build](#building-for-web).

```cmake
add_executable(my_project) # Emscripten outputs this as a JS and Wasm file
target_link_libraries(my_project PRIVATE executorch_wasm)
target_link_options(my_project PRIVATE --pre-js my_code.js) # Add any additional link options here
```

The output `my_project.js` should contain both the emitted JS code and the contents of `my_code.js` prepended.

## JavaScript API

### Module
- `static load(data)`: Load a model from a file or a buffer.
- `getMethods()`: Returns the list of methods in the model.
- `loadMethod(methodName)`: Load a method from the model.
- `getMethodMetadata(methodName)`: Get the metadata of a method.
- `execute(methodName, inputs)`: Execute a method with the given inputs.
- `forward(inputs)`: Execute the forward method with the given inputs.
- `delete()`: Delete the model from memory.

### Tensor
- `static zeroes(shape, dtype=ScalarType.Float)`: Create a tensor of zeros with the given shape and dtype.
- `static ones(shape, dtype=ScalarType.Float)`: Create a tensor of ones with the given shape and dtype.
- `static full(shape, value, dtype=ScalarType.Float)`: Create a tensor filled with the given value with the given shape and dtype.
- `static fromArray(shape, array, dtype=ScalarType.Float, dimOrder=[], strides=[])`: Create a tensor from a JavaScript array.
- `static fromIter(shape, iter, dtype=ScalarType.Float, dimOrder=[], strides=[])`: Create a tensor from an iterable.
- `delete()`: Delete the tensor from memory.
- `scalarType`: The scalar type of the tensor.
- `data`: The data buffer of the tensor.
- `sizes`: The sizes of the tensor.

### MethodMeta
- `name`: The name of the method.
- `inputTags`: The input tags of the method.
- `inputTensorMeta`: The input tensor metadata of the method.
- `outputTags`: The output tags of the method.
- `outputTensorMeta`: The output tensor metadata of the method.
- `attributeTensorMeta`: The attribute tensor metadata of the method.
- `memoryPlannedBufferSizes`: The memory planned buffer sizes of the method.
- `backends`: The backends of the method.
- `numInstructions`: The number of instructions in the method.
- These are value types and do not need to be manually deleted.

### TensorInfo
- `sizes`: The sizes of the tensor.
- `dimOrder`: The dimension order of the tensor.
- `scalarType`: The scalar type of the tensor.
- `isMemoryPlanned`: Whether the tensor is memory planned.
- `nBytes`: The number of bytes in the tensor.
- `name`: The name of the tensor.
- These are value types and do not need to be manually deleted.

### ScalarType
- Only `Float` and `Long` are currently supported.
- `value`: The int constant value of the enum.
- `name`: The `ScalarType` as a string.

### Tag
- `value`: The int constant value of the enum.
- `name`: The `Tag` as a string.

Emscripten's JavaScript API is also available; you can find more information in the [API Reference](https://emscripten.org/docs/api_reference/index.html).

extension/wasm/test/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
@@ -41,13 +41,13 @@ add_custom_target(
 )
 
 add_executable(executorch_wasm_tests)
-target_link_libraries(executorch_wasm_tests PUBLIC executorch_wasm)
+target_link_libraries(executorch_wasm_tests PRIVATE executorch_wasm)
 target_link_options(
   executorch_wasm_tests
-  PUBLIC
+  PRIVATE
   --embed-file
   "${MODELS_DIR}@/"
-  --post-js
+  --pre-js
   ${CMAKE_CURRENT_SOURCE_DIR}/unittests.js
   -sASSERTIONS=2
 )

extension/wasm/test/unittests.js

Lines changed: 2 additions & 2 deletions
@@ -6,9 +6,9 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-let et;
+var Module = {};
+const et = Module;
 beforeAll((done) => {
-  et = Module;
   et.onRuntimeInitialized = () => {
     done();
   }

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -72,6 +72,8 @@ dependencies=[
   "typing-extensions>=4.10.0",
   # Keep this version in sync with: ./backends/apple/coreml/scripts/install_requirements.sh
   "coremltools==8.3; platform_system == 'Darwin' or platform_system == 'Linux'",
+  # scikit-learn is used to support palettization in the coreml backend
+  "scikit-learn==1.7.1",
   "hydra-core>=1.3.0",
   "omegaconf>=2.3.0",
 ]
