Skip to content

Commit 11bae28

Browse files
minettekau, johnrachwan123, and llcnt
authored
feat: add compilation algorithms (#443)
* fixed conflicts * fixed linting errors * fixing more linting errors * removing uv.lock * fixing linting error * fixing linting error * fixing intel version * fix: ty to match main * Co-author Co-authored-by: John Rachwan <johnrachwan@gmail.com> Co-authored-by: Louis Leconte <louis.leconte@ens-paris-saclay.fr> --------- Co-authored-by: John Rachwan <johnrachwan@gmail.com> Co-authored-by: Louis Leconte <louis.leconte@ens-paris-saclay.fr>
1 parent 3b5772e commit 11bae28

File tree

6 files changed

+652
-1
lines changed

6 files changed

+652
-1
lines changed

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ name = "pruna_internal"
7070
url = "https://prunaai.pythonanywhere.com/simple/" # Pruna Pythonanywhere
7171
default = true # default = True makes this index the lowest prio
7272

73+
[[tool.uv.index]]
74+
name = "intel-pytorch-extension"
75+
url = "https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/"
76+
77+
[tool.uv]
78+
index-strategy = "unsafe-best-match"
79+
7380
[tool.uv.sources]
7481
gptqmodel = [
7582
{ index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'"},
@@ -187,6 +194,9 @@ dev = [
187194
"pytest-xdist>=3.8.0",
188195
]
189196
cpu = []
197+
intel = [
198+
"intel-extension-for-pytorch>=2.7.0",
199+
]
190200

191201
[build-system]
192202
requires = ["hatchling"]

src/pruna/algorithms/ipex_llm.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import re
15+
from pathlib import Path
16+
from typing import Any, Dict
17+
18+
import torch
19+
from ConfigSpace import OrdinalHyperparameter
20+
21+
from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
22+
from pruna.algorithms.base.tags import AlgorithmTag
23+
from pruna.config.smash_config import SmashConfigPrefixWrapper
24+
from pruna.engine.save import SAVE_FUNCTIONS
25+
from pruna.logging.logger import pruna_logger
26+
27+
28+
class IPEXLLM(PrunaAlgorithmBase):
    """
    Implement IPEX LLM compilation using the intel library.

    This compiler leverages advanced graph optimizations, quantization, and kernel fusion techniques to accelerate
    PyTorch-based LLM inference on Intel CPUs.

    Note: After compilation, the model supports sequence lengths that are either ≤ 32, or even numbers.
    """

    algorithm_name: str = "ipex_llm"
    group_tags: list[AlgorithmTag] = [AlgorithmTag.COMPILER]
    references: dict[str, str] = {"Github": "https://github.com/intel/intel-extension-for-pytorch"}
    tokenizer_required: bool = False
    processor_required: bool = False
    dataset_required: bool = False
    # Save before applying: ipex.llm.optimize is called with inplace=True below,
    # so the original model must be persisted first.
    save_fn = SAVE_FUNCTIONS.save_before_apply
    runs_on: list[str] = ["cpu"]
    compatible_before: list[str] = ["half"]
    required_install = (
        "``pip install pruna[intel]`` "
        "``--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/``"
    )

    def get_hyperparameters(self) -> list:
        """
        Get the hyperparameters for IPEX LLM compilation.

        Returns
        -------
        list
            The hyperparameters.
        """
        return [
            OrdinalHyperparameter(
                "weight_bits",
                sequence=[8, 4],
                default_value=8,
                meta=dict(desc="The number of bits to use for weight quantization."),
            ),
        ]

    def model_check_fn(self, model: Any) -> bool:
        """
        Check if the model is compatible with IPEX LLM compilation.

        Compatibility is decided by scanning IPEX's reference model
        implementations (``transformers/models/reference/models.py`` inside the
        installed package) for ``<ClassName>_forward`` functions and testing
        whether the model's class name appears among them.

        Parameters
        ----------
        model : Any
            The model to check.

        Returns
        -------
        bool
            Whether the model is compatible with IPEX LLM compilation.
        """
        imported_modules = self.import_algorithm_packages()
        # Locate the installed ipex package on disk so we can inspect its
        # bundled reference model implementations.
        ipex_path = Path(imported_modules["ipex"].__file__).parent
        models_path = ipex_path / "transformers" / "models" / "reference" / "models.py"
        if not models_path.exists():
            pruna_logger.warning("IPEX models.py file not found. Please check if IPEX is installed correctly.")
            return False
        content = models_path.read_text(encoding="utf-8")
        # IPEX names its reference forward functions "<ModelClassName>_forward",
        # hence the uppercase-initial match plus the "_forward" suffix filter.
        funcs = [f for f in re.findall(r"def\s+([A-Z][a-zA-Z0-9_]*)\s*\(", content) if f.endswith("_forward")]
        # removesuffix (not replace) so only the trailing "_forward" is stripped,
        # even if the class name itself were to contain "_forward".
        compatible_list = [name.removesuffix("_forward") for name in funcs]
        return model.__class__.__name__ in compatible_list

    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
        """
        Compile the model with IPEX LLM.

        Applies weight-only quantization (INT8 or INT4, per the ``weight_bits``
        hyperparameter) and IPEX's LLM-specific graph optimizations in place.

        Parameters
        ----------
        model : Any
            The model to compile.
        smash_config : SmashConfigPrefixWrapper
            The configuration to use for compilation.

        Returns
        -------
        Any
            The compiled model.
        """
        imported_modules = self.import_algorithm_packages()
        ipex = imported_modules["ipex"]
        woq_weight_dtype = imported_modules["WoqWeightDtype"]

        # Only 8 and 4 are valid weight_bits values (see get_hyperparameters).
        weight_dtype = woq_weight_dtype.INT8 if smash_config["weight_bits"] == 8 else woq_weight_dtype.INT4

        # Compute low-precision ops in INT8 regardless of the weight dtype.
        lowp_mode = ipex.quantization.WoqLowpMode.INT8

        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(weight_dtype=weight_dtype, lowp_mode=lowp_mode)

        model = ipex.llm.optimize(
            model.eval(),
            dtype=torch.float32,
            quantization_config=qconfig,
            low_precision_checkpoint=None,
            deployment_mode=True,
            inplace=True,
        )

        return model

    def import_algorithm_packages(self) -> Dict[str, Any]:
        """
        Import the algorithm packages.

        Returns
        -------
        Dict[str, Any]
            The algorithm packages.
        """
        # Import necessary modules here to avoid unnecessary imports and ensure they're available when needed
        import intel_extension_for_pytorch as ipex
        from intel_extension_for_pytorch.quantization import WoqWeightDtype

        return dict(
            ipex=ipex,
            WoqWeightDtype=WoqWeightDtype,
        )

0 commit comments

Comments
 (0)