37 changes: 36 additions & 1 deletion examples/apple/coreml/llama/export.py
@@ -28,6 +28,7 @@

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
from torchao.utils import unwrap_tensor_subclass


@@ -77,7 +78,7 @@ def main() -> None:
parser.add_argument(
"--coreml-quantize",
default=None,
choices=["b4w", "c4w"],
choices=["b4w", "c4w", "custom",],
help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)",
)
parser.add_argument(
@@ -118,6 +119,8 @@ def main() -> None:
model.eval()
model.to(float_dtype)

print("MODEL", model)

if export_args.target_split_size is not None:
replace_linear_with_split_linear(
model,
@@ -163,6 +166,34 @@ def main() -> None:
granularity=PerAxis(0),
),
)
elif export_args.coreml_quantize == "custom":
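# Custom recipe: split the FFN projections (w1/w3 along the output dimension, w2 along the
# input dimension), then apply 4-bit per-axis weight-only quantization to the attention,
# FFN, and output linears.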
replace_linear_with_split_linear(
model,
out_target_split_size=2048,
out_max_splits=4,
in_target_split_size=1,
in_max_splits=1,
fqn_filter=lambda fqn: any(fqn.endswith(suffix) for suffix in ["w1", "w3"])
)
replace_linear_with_split_linear(
model,
out_target_split_size=2048,
out_max_splits=1,
in_target_split_size=2048,
in_max_splits=4,
fqn_filter=lambda fqn: any(fqn.endswith(suffix) for suffix in ["w2"])
)
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerAxis(0),
),
lambda m, fqn: (
isinstance(m, torch.nn.Linear)
and any(fqn.endswith(suffix) for suffix in ["wq", "wk", "wv", "wo", "output", "w1", "w3", "w2"])
Review comment (Contributor Author): This needs to be suffix in fqn instead of fqn.endswith(suffix) in order to capture the split linear modules. This is because once split linear is applied, the fqn no longer ends with w1, w2, and w3.

),
)
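
A quick illustration of the filter issue noted in the review comment above; the post-split fqn shown is hypothetical, since the exact name depends on how SplitLinearModule registers its sub-linears:

# Before splitting, the projection's fqn ends with its name, so both checks match.
fqn = "layers.0.feed_forward.w1"
print(fqn.endswith("w1"), "w1" in fqn)  # True True

# After replace_linear_with_split_linear, the sub-linear fqns grow a suffix
# (hypothetical naming shown), so only the substring check still matches.
split_fqn = "layers.0.feed_forward.w1.splits.0"
print(split_fqn.endswith("w1"), "w1" in split_fqn)  # False True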

compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
minimum_deployment_target=ct.target.iOS18,
@@ -199,6 +230,10 @@ def main() -> None:
print("Exported program")
print(ep)

# ep = ep.run_decompositions({})
# mlprogram = ct.convert(ep, minimum_deployment_target=ct.target.iOS18)
# mlprogram.save("model.mlpackage")

edge_manager = to_edge_transform_and_lower(
ep,
partitioner=[partitioner],
45 changes: 40 additions & 5 deletions examples/apple/coreml/llama/llama_transformer.py
@@ -167,6 +167,40 @@ def forward(self, x):
output = self._norm(x)
return output * self.weight

class CoreMLRMSNormV2(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.

"""
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def forward(self, x):
"""
Apply the RMSNorm normalization to the input tensor.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The normalized tensor.

"""

# eps is passed as None here, so rms_norm falls back to its dtype-based default
# (torch.finfo(x.dtype).eps) rather than the stored self.eps.
return torch.nn.functional.rms_norm(x, normalized_shape=[self.dim], weight=self.weight, eps=None)

# Module-level switch: point this at CoreMLRMSNormV2 to use the fused
# torch.nn.functional.rms_norm path instead of the manual implementation.
_RMS_NORM = CoreMLRMSNorm
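
A small sanity-check sketch (not part of the diff) comparing CoreMLRMSNormV2 against the reference RMSNorm formula; it assumes the class above is in scope and that small eps differences are negligible:

import torch

x = torch.randn(2, 8, 64)
norm = CoreMLRMSNormV2(64)
# Reference RMSNorm: x / sqrt(mean(x^2) + eps), scaled by the learnable weight.
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
print(torch.allclose(norm(x), ref, atol=1e-5))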

class Rope(torch.nn.Module):
def __init__(self, params: ModelArgs):
@@ -327,8 +361,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
if self.use_qk_norm:
q_norm_dim = self.head_dim
k_norm_dim = self.head_dim
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
self.q_norm_fn = _RMS_NORM(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = _RMS_NORM(k_norm_dim, eps=args.norm_eps)

def forward(
self,
@@ -364,6 +398,7 @@ def forward(
k = torch.concat([k_cache, k], dim=2)
v = torch.concat([v_cache, v], dim=2)

# TODO: I'm pretty sure the MB version of SDPA does not require this repeat_interleave,
# grouped multiquery attention: expand out keys and values
if self.n_rep > 1:
k = k.repeat_interleave(self.n_rep, dim=1)
@@ -388,8 +423,8 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
self.block_sparse_moe = MOEFeedForward(args)
else:
self.feed_forward = FeedForward(args)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.attention_norm = _RMS_NORM(args.dim, eps=args.norm_eps)
self.ffn_norm = _RMS_NORM(args.dim, eps=args.norm_eps)

def forward(
self,
@@ -422,7 +457,7 @@ def __init__(self, params: ModelArgs):
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params, self.rope))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.norm = _RMS_NORM(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.generate_full_logits = params.generate_full_logits
self.max_seq_len = params.max_seq_len
10 changes: 8 additions & 2 deletions examples/apple/coreml/llama/utils.py
@@ -91,10 +91,15 @@ def forward(self, x):


def replace_linear_with_split_linear(
model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1
model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1, fqn_filter=None,
):
if fqn_filter is None:
fqn_filter = lambda fqn: True

for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
should_split = isinstance(module, torch.nn.Linear) and fqn_filter(name)
print("TESTING", name, "WILL SPLIT", should_split)
if should_split:
assert module.bias is None, "SplitLinearModule does not support bias"
new_module = SplitLinearModule(
module.in_features,
@@ -113,4 +118,5 @@ def replace_linear_with_split_linear(
out_max_splits,
in_target_split_size,
in_max_splits,
fqn_filter,
)
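
A minimal usage sketch of the new filter hook (the toy module and split sizes are made up for illustration, and it assumes replace_linear_with_split_linear from this utils module is already in scope); only children whose name passes fqn_filter get split:

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = torch.nn.Linear(64, 4096, bias=False)
        self.w2 = torch.nn.Linear(4096, 64, bias=False)

toy = Toy()
# Split only w1 along its output dimension; w2 fails the filter and is left untouched.
replace_linear_with_split_linear(
    toy,
    out_target_split_size=2048,
    out_max_splits=2,
    in_target_split_size=1,
    in_max_splits=1,
    fqn_filter=lambda fqn: fqn.endswith("w1"),
)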
2 changes: 1 addition & 1 deletion examples/apple/coreml/scripts/extract_coreml_models.py
@@ -21,7 +21,7 @@


def extract_coreml_models(pte_data: bytes):
program = deserialize_pte_binary(pte_data)
program = deserialize_pte_binary(pte_data).program
delegates: List[BackendDelegate] = sum(
[execution_plan.delegates for execution_plan in program.execution_plan], []
)