Skip to content

Commit efe4756

Browse files
committed
init
1 parent 3db27cd commit efe4756

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed

backends/apple/coreml/compiler/coreml_preprocess.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import coremltools as ct
1818
import coremltools.optimize as cto
1919
from executorch.backends.apple.coreml import executorchcoreml
20+
from numpy import isin
2021
from executorch.backends.apple.coreml.logging import get_coreml_log_level
2122
from executorch.exir.backend.backend_details import (
2223
BackendDetails,
@@ -37,6 +38,7 @@ class COMPILE_SPEC_KEYS(Enum):
3738
MIN_DEPLOYMENT_TARGET = "min_deployment_target"
3839
MODEL_COMPUTE_PRECISION = "model_compute_precision"
3940
OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"
41+
CT_INPUTS = "ct_inputs"
4042

4143

4244
class MODEL_PATHS(Enum):
@@ -214,6 +216,152 @@ def op_linear_quantizer_config_from_compile_specs(
214216

215217
return None
216218

219+
220+
@staticmethod
def generate_ct_inputs_compile_spec(
    ct_inputs: List[ct.TensorType],
) -> CompileSpec:
    """
    Returns the compile spec representing the model inputs.

    Generally this is not needed, but is used to specify things that cannot be
    inferred from the exported program, like enumerated shapes.

    Each input is serialized to a JSON object with the keys "kind", "name",
    "shape", "dtype" and "default_value"; the list of objects is UTF-8 encoded
    and stored under COMPILE_SPEC_KEYS.CT_INPUTS.

    Args:
        ct_inputs: The ct.TensorType inputs to serialize. "name", "dtype" and
            "default_value" must all be None on every entry.

    Returns:
        A CompileSpec whose value is the UTF-8 encoded JSON payload.

    Raises:
        TypeError: If an entry is not a ct.TensorType, or a shape is not None,
            a list/tuple/ct.Shape of ints, or ct.EnumeratedShapes of such shapes.
        ValueError: If "name", "dtype" or "default_value" is set on an entry.
    """

    def _is_int_shape(seq):
        # bool is a subclass of int, but True/False are not valid dimensions.
        return isinstance(seq, (list, tuple)) and all(
            isinstance(x, int) and not isinstance(x, bool) for x in seq
        )

    def _unwrap(shape):
        # NOTE(review): coremltools appears to wrap list/tuple shapes in ct.Shape
        # when a TensorType is constructed, while EnumeratedShapes.default is
        # documented as a plain tuple -- confirm against the pinned coremltools
        # version. Handling both forms keeps serialization working either way.
        return shape.shape if isinstance(shape, ct.Shape) else shape

    def _serialize_shape(shape):
        # Case 1: None
        if shape is None:
            return {"kind": "fixed", "shape": None}

        # Case 2: EnumeratedShapes (entries are ct.Shape or plain sequences)
        if isinstance(shape, ct.EnumeratedShapes):
            shapes = []
            for entry in shape.shapes:
                dims = _unwrap(entry)
                if not _is_int_shape(dims):
                    raise TypeError(
                        "EnumeratedShapes entries must be tuples/lists of ints"
                    )
                shapes.append(list(dims))

            default = None
            if shape.default is not None:
                default_dims = _unwrap(shape.default)
                if not _is_int_shape(default_dims):
                    raise TypeError(
                        "EnumeratedShapes.default must be a tuple/list of ints"
                    )
                default = list(default_dims)
            return {"kind": "enumerated", "shapes": shapes, "default": default}

        # Case 3: fixed shape, given as a list/tuple of ints or a ct.Shape of ints
        dims = _unwrap(shape)
        if _is_int_shape(dims):
            return {"kind": "fixed", "shape": list(dims)}

        # Anything else (e.g. RangeDim-based shapes) is out of scope for now
        raise TypeError(
            "Shape must be EnumeratedShapes, a list/tuple of ints, or None"
        )

    def tensor_type_to_dict(t: ct.TensorType):
        # Explicit raises instead of asserts so validation survives `python -O`.
        if not isinstance(t, ct.TensorType):
            raise TypeError("ct_inputs entries must be ct.TensorType")
        for attr in ("name", "dtype", "default_value"):
            if getattr(t, attr) is not None:
                raise ValueError(f"{attr} cannot be given a value")
        return {
            "kind": "TensorType",
            "name": t.name,
            "shape": _serialize_shape(t.shape),
            "dtype": t.dtype,
            "default_value": t.default_value,
        }

    str_representation = json.dumps(
        [tensor_type_to_dict(ct_in) for ct_in in ct_inputs]
    )
    byte_representation = str_representation.encode("utf-8")
    return CompileSpec(
        COMPILE_SPEC_KEYS.CT_INPUTS.value,
        byte_representation,
    )
278+
279+
@staticmethod
def ct_inputs_from_compile_specs(
    compile_specs: List["CompileSpec"],
) -> Optional[List[ct.TensorType]]:
    """
    Returns the model's ct.inputs by parsing the list of compile specs.

    Only the first compile spec whose key is COMPILE_SPEC_KEYS.CT_INPUTS is
    parsed; None is returned when no such spec exists.

    Expected JSON schema per entry (as produced by generate_ct_inputs_compile_spec):
    {
        "kind": "TensorType",
        "name": "<non-empty string>" | null,
        "shape": {
            "kind": "fixed", "shape": [int, ...] | null
          | "kind": "enumerated", "shapes": [[int, ...], ...], "default": [int, ...] | null
        }
    }
    Any "dtype" / "default_value" keys in an entry are ignored.

    Raises:
        ValueError: If the payload or an entry does not match the schema.
        TypeError: If a shape or default is not a list of ints.
    """

    def _is_int_shape(seq):
        # bool is a subclass of int; reject JSON true/false as dimensions.
        return isinstance(seq, (list, tuple)) and all(
            isinstance(x, int) and not isinstance(x, bool) for x in seq
        )

    def _parse_shape(shape_json):
        if not isinstance(shape_json, dict) or "kind" not in shape_json:
            raise ValueError("Invalid shape JSON: missing 'kind'")

        kind = shape_json["kind"]

        # Case: fixed
        if kind == "fixed":
            shp = shape_json.get("shape", None)
            if shp is None:
                return None
            if not _is_int_shape(shp):
                raise TypeError("Fixed shape must be a list/tuple of ints or null")
            return tuple(shp)

        # Case: enumerated
        if kind == "enumerated":
            shapes = shape_json.get("shapes", None)
            if not isinstance(shapes, list) or not shapes:
                raise ValueError("Enumerated shape must have non-empty 'shapes' list")

            parsed_shapes = []
            for s in shapes:
                if not _is_int_shape(s):
                    raise TypeError("EnumeratedShapes entries must be lists of ints")
                parsed_shapes.append(ct.Shape(tuple(s)))

            default = shape_json.get("default", None)
            default_shape = None
            if default is not None:
                if not _is_int_shape(default):
                    raise TypeError("EnumeratedShapes.default must be a list of ints")
                # NOTE(review): coremltools documents EnumeratedShapes(default=...)
                # as a tuple of ints rather than a ct.Shape -- confirm against
                # the pinned coremltools version.
                default_shape = tuple(default)

            return ct.EnumeratedShapes(shapes=parsed_shapes, default=default_shape)

        raise ValueError(f"Unsupported shape kind: {kind}")

    for compile_spec in compile_specs:
        if compile_spec.key != COMPILE_SPEC_KEYS.CT_INPUTS.value:
            continue

        payload = json.loads(compile_spec.value.decode("utf-8"))
        if not isinstance(payload, list):
            raise ValueError("CT_INPUTS payload must be a list")

        ct_inputs: List[ct.TensorType] = []
        for entry in payload:
            if not isinstance(entry, dict) or entry.get("kind") != "TensorType":
                raise ValueError("Each entry must be a dict with kind == 'TensorType'")

            # generate_ct_inputs_compile_spec only emits null names, so a
            # missing/null name must be accepted; a provided name still has to
            # be a non-empty string.
            name = entry.get("name", None)
            if name is not None and (not isinstance(name, str) or not name):
                raise ValueError("TensorType.name must be a non-empty string or null")

            shape_json = entry.get("shape", None)
            shape = _parse_shape(shape_json) if shape_json is not None else None

            # dtype/default_value are required to be None by the generator,
            # so only name + shape are passed here.
            ct_inputs.append(ct.TensorType(name=name, shape=shape))

        return ct_inputs

    return None
364+
217365
@staticmethod
218366
def generate_compile_specs(
219367
compute_unit: ct.ComputeUnit = ct.ComputeUnit.ALL,
@@ -446,6 +594,9 @@ def preprocess(
446594
op_linear_quantizer_config = (
447595
CoreMLBackend.op_linear_quantizer_config_from_compile_specs(compile_specs)
448596
)
597+
enumerated_shapes = (
598+
CoreMLBackend.enumerate_shapes_from_compile_specs(compile_specs)
599+
)
449600

450601
# Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because
451602
# get_compiled_model_path() requires a loaded model.

0 commit comments

Comments
 (0)