 # coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
 # the op to the coremltools library.
 
-import torch as _torch
-from coremltools import _logger
-from coremltools.converters.mil.frontend import _utils
-from coremltools.converters.mil.frontend.torch.ops import (
-    _get_inputs,
-    _get_kwinputs,
-    NUM_TO_NUMPY_DTYPE,
-    NUM_TO_TORCH_DTYPE,
-    split,
-    to,
-    transpose,
-    unbind,
-)
 import numpy as np
-from coremltools.converters.mil.frontend.torch.torch_op_registry import (
-    register_torch_op,
-)
-from coremltools.converters.mil.mil import types
-from executorch.exir.dim_order_utils import get_memory_format
-
-
-# https://github.com/apple/coremltools/pull/2556
-@register_torch_op(override=False)
-def transpose_copy(context, node):
-    transpose(context, node)
-
-
-# https://github.com/apple/coremltools/pull/2557
-@register_torch_op(override=False)
-def unbind_copy(context, node):
-    unbind(context, node)
-
-
-# https://github.com/apple/coremltools/pull/2563
-@register_torch_op(override=False)
-def split_copy(context, node):
-    split(context, node)
-
-
-@register_torch_op(
-    torch_alias=[
-        "dim_order_ops::_to_dim_order_copy",
-        "dim_order_ops._to_dim_order_copy",
-    ],
-    override=False,
-)
-def _to_dim_order_copy(context, node):
-    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
-    node.kwinputs.pop("dim_order")
-
-    # In CoreML, dim_order.val will be an ndarray, so we convert it to a list
-    dim_order = [int(d) for d in dim_order.val]
-    memory_format = get_memory_format(dim_order)
-    assert (
-        memory_format == _torch.contiguous_format
-    ), "Only contiguous memory format is supported in CoreML"
-    to(context, node)
-
-
-# https://github.com/apple/coremltools/pull/2558
-@register_torch_op(
-    torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
-    override=False,
-)
-def dequantize_affine(context, node):
-    inputs = _get_inputs(context, node, expected=[7, 8])
-    int_data = inputs[0].val
-    block_size = inputs[1].val
-    scale = inputs[2].val
-    zero_point = (
-        inputs[3].val if inputs[3] is not None and inputs[3].val is not None else None
-    )
-    # I do not think we need to worry about input_dtype b/c it gets cast to int4/int8
-    # For now, we just check that it is int8 or int32
-    input_dtype = inputs[4].val  # noqa: F841
-    assert NUM_TO_TORCH_DTYPE[input_dtype] in [
-        _torch.int8,
-        _torch.int32,
-    ], "input_dtype should be int8 or int32"
-
-    quant_min = inputs[5].val
-    quant_max = inputs[6].val
-
-    assert len(int_data.shape) == 2, "dequantize_affine only supports rank 2 inputs"
-
-    assert len(int_data.shape) == len(
-        block_size
-    ), "block_size must have the same length as int_data.shape"
-    assert block_size[0] == 1, "block_size[0] must be 1"
-    group_size = block_size[1]
-    k = int_data.shape[1]
-    assert k % group_size == 0, "k must be divisible by group_size"
-    scales_per_row = k // group_size
-    scale = scale.reshape(-1, scales_per_row)
-    if zero_point is not None:
-        zero_point = zero_point.reshape(-1, scales_per_row)
-
-    # TODO: I don't know if CoreML can make use of this
-    # We could add a cast op to the output, but I'm pretty sure CoreML will remove this during a later pass
-    # For now, we just log a warning
-    out_np_dtype = None
-    if len(inputs) > 7:
-        out_np_dtype = NUM_TO_NUMPY_DTYPE[inputs[7].val]
-        _logger.warning(
-            f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
-        )
-
-    if quant_min == -8 and quant_max == 7:
-        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int4"))
-    elif quant_min == -128 and quant_max == 127:
-        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int8"))
-    else:
-        raise ValueError(
-            f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization."
-        )
-
-    output = _utils._construct_constexpr_dequant_op(
-        int_data.astype(quantized_np_dtype),
-        zero_point,
-        scale,
-        axis=-1,
-        name=node.name,
-    )
-    context.add(output, node.name)
-
-
-
-# codes: torch.Tensor,
-# codebook: torch.Tensor,
-# code_dtype: torch.dtype,
-# block_size: List[int],
-# output_dtype: torch.dtype = torch.float32,
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# This file registers torch ops that are not yet in coremltools, or are in a more recent version of
-# coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
-# the op to the coremltools library.
-
 import torch as _torch
 from coremltools import _logger
 from coremltools.converters.mil.frontend import _utils
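For context on the group-quantization layout that dequantize_affine asserts above (rank-2 int_data, block_size of [1, group_size], one scale per group of columns), here is a minimal NumPy sketch. The sizes are hypothetical and chosen only to satisfy the same checks; it is not part of the op registration.

import numpy as np

# Hypothetical sizes: 4 rows, 64 columns, groups of 32 columns per scale.
n, k, group_size = 4, 64, 32
block_size = [1, group_size]

int_data = np.random.randint(-8, 8, size=(n, k), dtype=np.int8)   # values in the int4 range
scale = np.random.rand(n * (k // group_size)).astype(np.float32)  # one scale per group

assert len(int_data.shape) == len(block_size)
assert block_size[0] == 1
assert k % group_size == 0

scales_per_row = k // group_size           # 2 groups per row in this sketch
scale = scale.reshape(-1, scales_per_row)  # (n, scales_per_row), as the op reshapes it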
@@ -164,8 +22,6 @@ def dequantize_affine(context, node):
     transpose,
     unbind,
 )
-import numpy as np
-
 from coremltools.converters.mil.frontend.torch.torch_op_registry import (
     register_torch_op,
 )
@@ -283,24 +139,25 @@ def dequantize_affine(context, node):
     override=False,
 )
 def dequantize_codebook(context, node):
-    print("IN DEQUANTIZE CODEBOOK")
     inputs = _get_inputs(context, node, expected=[4, 5])
     codes = inputs[0].val
     codebook = inputs[1].val
-    code_dtype = inputs[2].val
-    block_size = inputs[3].val
+    nbits = inputs[2].val
 
-
-    assert len(codes.shape) == 2, "Only rank 2 inputs are supported"
+    # information in block_size is redundant with codebook.shape
+    block_size = inputs[3].val  # noqa: F841
 
-    # In TorchAO, the codebook shape is (n_lut, nbit, 1). The LUTs are for the columns.
-    # In CoreML, the expected shape is (lut_block_size, nbit, 1). 1 here is for scalar
-    # lut_block_size has the same rank as codes/idxs and tells you how many LUTs there are per block, e.g.,
-    # lut_block_size=(1, 8) means there is 1 LUT per 8 columns
-    assert len(codebook.shape) == 3, "Only rank 3 inputs are supported for codebook"
-    assert codebook.shape[-1] == 1, "we only support scalar palletization"
-    codebook = np.expand_dims(codebook, 0)
+    assert len(codes.shape) == 2, "Only rank 2 inputs are supported"
 
+    # Assert codebook is as expected. codebook.dim() = codes.dim() + 2
+    assert len(codebook.shape) == 4, "Only rank 4 inputs are supported for codebook"
+    assert codebook.shape[0] == 1, "Only grouped_channel granularity is supported"
+    n_luts = codebook.shape[1]
+    assert (
+        codes.shape[1] % n_luts == 0
+    ), "codes.shape[1] must be divisible by codebook.shape[1]"
+    assert codebook.shape[2] == 2**nbits
+    assert codebook.shape[3] == 1, "Only scalar look up values are supported"
 
     if len(inputs) > 4:
         output_dtype = inputs[4].val
@@ -309,11 +166,6 @@ def dequantize_codebook(context, node):
             f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
         )
 
-    print("codes", codes.shape)
-    print("codebook", codebook.shape)
-    print("code_dtype", code_dtype)
-    print("block_size", block_size)
-
     output = _utils._construct_constexpr_lut_op(
         codes.astype(np.int8),
         codebook,
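The new assertions in dequantize_codebook encode a specific codebook layout: codes is rank 2 and codebook has shape (1, n_luts, 2**nbits, 1). A minimal NumPy sketch with hypothetical sizes, only to illustrate that contract:

import numpy as np

nbits = 4          # each code indexes a 2**nbits-entry look-up table
n, k = 8, 64       # codes is rank 2 with shape (n, k)
n_luts = 4         # codes.shape[1] must be divisible by the number of LUTs

codes = np.random.randint(0, 2**nbits, size=(n, k), dtype=np.int8)

# codebook.dim() == codes.dim() + 2, shape (1, n_luts, 2**nbits, 1):
#   axis 0 = 1         -> grouped_channel granularity only
#   axis 1 = n_luts    -> one LUT per group of k // n_luts columns
#   axis 2 = 2**nbits  -> one entry per possible code value
#   axis 3 = 1         -> scalar (non-vector) look-up values only
codebook = np.random.rand(1, n_luts, 2**nbits, 1).astype(np.float32)

assert len(codes.shape) == 2
assert len(codebook.shape) == 4 and codebook.shape[0] == 1
assert codes.shape[1] % codebook.shape[1] == 0
assert codebook.shape[2] == 2**nbits and codebook.shape[3] == 1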