15 | 15 | from coremltools.converters.mil.frontend.torch.ops import ( |
16 | 16 |     _get_inputs,
17 | 17 |     _get_kwinputs,
| 18 | +    noop,
18 | 19 |     NUM_TO_NUMPY_DTYPE,
19 | 20 |     NUM_TO_TORCH_DTYPE,
20 | 21 |     split,
@@ -91,6 +92,28 @@ def _to_dim_order_copy(context, node): |
91 | 92 |     to(context, node)
92 | 93 |
93 | 94 |
| 95 | +@register_torch_op( |
| 96 | +    torch_alias=[
| 97 | +        "dim_order_ops::_clone_dim_order",
| 98 | +        "dim_order_ops._clone_dim_order",
| 99 | +    ],
| 100 | +    override=False,
| 101 | +)
| 102 | +def _clone_dim_order(context, node): |
| 103 | +    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
| 104 | +    node.kwinputs.pop("dim_order")
| 105 | +
| 106 | +    # In CoreML, dim_order.val is an ndarray, so convert it to a list to check the memory format.
| 107 | +    dim_order = [int(d) for d in dim_order.val]
| 108 | +    memory_format = get_memory_format(dim_order)
| 109 | +    assert (
| 110 | +        memory_format == _torch.contiguous_format
| 111 | +    ), "Only contiguous memory format is supported in CoreML"
| 112 | +
| 113 | +    # Since CoreML supports only the contiguous memory format, there is no dim_order to preserve; treat this as a no-op clone.
| 114 | +    noop(context, node)
| 115 | +
| 116 | + |
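For context, `get_memory_format` (imported elsewhere in this file) maps a dim_order list back to a torch memory format. A minimal sketch of that mapping, assuming only the contiguous and rank-4 channels-last layouts are distinguished (illustrative, not the actual implementation):

```python
import torch

def get_memory_format(dim_order):
    # Contiguous tensors store dims in ascending order: [0, 1, ..., rank-1].
    if dim_order == list(range(len(dim_order))):
        return torch.contiguous_format
    # Rank-4 channels-last (NHWC) stores dims in the order [0, 2, 3, 1].
    if dim_order == [0, 2, 3, 1]:
        return torch.channels_last
    raise ValueError(f"Unsupported dim_order: {dim_order}")
```

Under that assumption, `_clone_dim_order` above accepts `dim_order=[0, 1, 2, 3]` and rejects the channels-last order `[0, 2, 3, 1]`.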
94 | 117 | # https://github.com/apple/coremltools/pull/2558 |
95 | 118 | @register_torch_op( |
96 | 119 |     torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
@@ -175,11 +198,22 @@ def dequantize_codebook(context, node): |
175 | 198 |
176 | 199 |     # Assert codebook is as expected. codebook.dim() = codes.dim() + 2
177 | 200 |     assert len(codebook.shape) == 4, "Only rank 4 inputs are supported for codebook"
178 | | -    assert codebook.shape[0] == 1, "Only grouped_channel granularity is supported"
179 | | -    n_luts = codebook.shape[1]
180 | | -    assert (
181 | | -        codes.shape[1] % n_luts == 0
182 | | -    ), "codes.shape[1] must be divisible by codebook.shape[1]"
| 201 | +    assert (codebook.shape[0] == 1) or (
| 202 | +        codebook.shape[1] == 1
| 203 | +    ), "Only grouped_channel granularity is supported"
| 204 | +    if codebook.shape[0] == 1:
| 205 | +        # LUT is per column group
| 206 | +        n_luts = codebook.shape[1]
| 207 | +        assert (
| 208 | +            codes.shape[1] % n_luts == 0
| 209 | +        ), "codes.shape[1] must be divisible by codebook.shape[1]"
| 210 | +    else:
| 211 | +        # LUT is per row group
| 212 | +        n_luts = codebook.shape[0]
| 213 | +        assert (
| 214 | +            codes.shape[0] % n_luts == 0
| 215 | +        ), "codes.shape[0] must be divisible by codebook.shape[0]"
| 216 | +
183 | 217 |     assert codebook.shape[2] == 2**nbits
184 | 218 |     assert codebook.shape[3] == 1, "Only scalar look up values are supported"
185 | 219 |
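To make the new grouping arithmetic concrete, here is a worked example with hypothetical shapes (all values below are illustrative; with `nbits = 4`, each LUT has 2**4 = 16 entries):

```python
import numpy as np

nbits = 4
codes = np.zeros((8, 12), dtype=np.int8)  # hypothetical rank-2 quantized codes

# Per-column grouping: codebook.shape[0] == 1, LUT count on axis 1.
codebook = np.zeros((1, 4, 2**nbits, 1))
assert codes.shape[1] % codebook.shape[1] == 0  # 12 % 4 == 0, four column groups

# Per-row grouping: codebook.shape[1] == 1, LUT count on axis 0.
codebook = np.zeros((4, 1, 2**nbits, 1))
assert codes.shape[0] % codebook.shape[0] == 0  # 8 % 4 == 0, four row groups
```

Either way, `codebook.shape[2] == 2**nbits` and `codebook.shape[3] == 1` still hold, matching the unchanged asserts above.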