
Commit 69b5ea5

Add IPEX patch fusion linear for bert and vit (#786)
* add bert for question answering
* fix check
* add ViT model for image classification
* fix tasks
* fix variable name
* add test patching
* add vit patching tests
* fix name
* skip testing patch if ipex < 2.3
* fix traced model patch check
1 parent 48cc82a commit 69b5ea5
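
For orientation: the fused path added by this commit is exercised through optimum-intel's IPEX model classes. A minimal usage sketch, assuming ipex >= 2.3.0 is installed and using the tiny test checkpoint referenced later in this diff (export=True goes through the ipex_jit_trace path changed below):

    import torch
    from transformers import AutoTokenizer
    from optimum.intel import IPEXModelForQuestionAnswering

    # export=True traces the model; on ipex >= 2.3.0 the patcher first swaps
    # BertIntermediate for the fused LinearGelu-based _IPEXIntermediate.
    model_id = "hf-internal-testing/tiny-random-bert"
    model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inputs = tokenizer("This is a sample input", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.start_logits.shape, outputs.end_logits.shape)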

File tree

4 files changed (+111, -38 lines)

optimum/exporters/ipex/model_patcher.py

Lines changed: 27 additions & 11 deletions
@@ -12,19 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaForCausalLM,
     LlamaModel,
     LlamaRMSNorm,
 )
+from transformers.models.vit.modeling_vit import ViTIntermediate

 from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version

 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
+    _ipex_rms_layer_norm_forward,
+    _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
-    _llama_layer_norm_forward,
     _llama_model_forward,
 )

@@ -33,8 +36,7 @@
 _TRANSFORMERS_MIN_VERSION = "4.39.0"
 _TRANSFORMERS_MAX_VERSION = "4.41.2"

-_IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
-_IPEX_EXPORTED_TASK = ("text-generation",)
+_IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)


 def convert_func(m, func_name, new_function):
@@ -49,7 +51,7 @@ def convert_functions(m, target_m, new_function_name, new_function):
         convert_functions(sub_m, target_m, new_function_name, new_function)


-def convert_class(m, target_m, new_class, config):
+def convert_class(m, target_m, new_class, config=None):
     for name, sub_m in m.named_children():
         if isinstance(sub_m, target_m):
             new_m = new_class(sub_m, config)
@@ -65,6 +67,23 @@ def patch_op(m, target_m, new_op_name, new_op):


 def _patch_llama_model(model):
+    convert_functions(model, LlamaModel, "forward", _llama_model_forward)
+    convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
+    return model
+
+
+def _patch_bert_model(model):
+    convert_class(model, BertIntermediate, _IPEXIntermediate)
+    return model
+
+
+def _patch_vit_model(model):
+    convert_class(model, ViTIntermediate, _IPEXIntermediate)
+    return model
+
+
+def _patch_model(model):
     if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
         raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching")
     if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
@@ -73,13 +92,10 @@ def _patch_llama_model(model):
         raise ImportError(
             f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
         )
-    convert_functions(model, LlamaModel, "forward", _llama_model_forward)
-    convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)
-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
-    return model
-
-
-def _patch_model(model):
     if isinstance(model, LlamaForCausalLM):
         model = _patch_llama_model(model)
+    elif model.config.model_type == "bert":
+        model = _patch_bert_model(model)
+    elif model.config.model_type == "vit":
+        model = _patch_vit_model(model)
     return model
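
The patching helpers above all follow one recursive pattern: walk named_children, swap matching submodules in place, then recurse. A self-contained sketch of that pattern with toy modules (ToyIntermediate and FusedToyIntermediate are illustrative names, not part of the repo):

    from torch import nn

    def convert_class(m, target_m, new_class, config=None):
        # Recursively replace every instance of target_m with new_class(sub_m, config).
        for name, sub_m in m.named_children():
            if isinstance(sub_m, target_m):
                setattr(m, name, new_class(sub_m, config))
            convert_class(sub_m, target_m, new_class, config)

    class ToyIntermediate(nn.Module):
        def __init__(self, hidden=8):
            super().__init__()
            self.dense = nn.Linear(hidden, hidden)

        def forward(self, x):
            return nn.functional.gelu(self.dense(x))

    class FusedToyIntermediate(nn.Module):
        # Stand-in for _IPEXIntermediate: reuses the original weights, new forward.
        def __init__(self, module, config=None):
            super().__init__()
            self.dense = module.dense

        def forward(self, x):
            return nn.functional.gelu(self.dense(x))

    model = nn.Sequential(ToyIntermediate(), nn.Sequential(ToyIntermediate()))
    convert_class(model, ToyIntermediate, FusedToyIntermediate)
    print(model)  # both ToyIntermediate blocks are now FusedToyIntermediate

Making config optional (config=None) is what lets _patch_bert_model and _patch_vit_model call convert_class without a config, while the llama path keeps passing model.config.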

optimum/exporters/ipex/modeling_utils.py

Lines changed: 31 additions & 11 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 import math
 from typing import List, Optional, Tuple, Union

@@ -25,11 +26,27 @@
 from optimum.intel.utils.modeling_utils import _setattr_from_module


+logger = logging.getLogger(__name__)
+
 _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0"


+if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
+    logger.warning(
+        f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model."
+    )
+else:
+    from intel_extension_for_pytorch.llm.modules import (
+        IndirectAccessKVCacheAttention,
+        Linear2SiluMul,
+        LinearAdd,
+        LinearGelu,
+        RotaryEmbedding,
+    )
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
-def _llama_layer_norm_forward(self, hidden_states):
+def _ipex_rms_layer_norm_forward(self, hidden_states):
     return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)


@@ -139,14 +156,9 @@ def _llama_model_forward(
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
 class _IPEXLlamaAttention(nn.Module):
     def __init__(self, module, config) -> None:
-        if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
-            raise ImportError(
-                f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding"
-            )
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding

         if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
             self.mha_linear_add = LinearAdd(module.o_proj)
@@ -296,14 +308,9 @@ def forward(
 # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186
 class _IPEXLlamaMLP(nn.Module):
     def __init__(self, module, config) -> None:
-        if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
-            raise ImportError(
-                f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports Linear2SiluMul, LinearAdd"
-            )
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd

         # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
         if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
@@ -398,3 +405,16 @@ def forward(
             outputs += (present_key_value,)

         return outputs
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
+class _IPEXIntermediate(nn.Module):
+    def __init__(self, module, config):
+        super().__init__()
+        _setattr_from_module(self, module)
+        self.linear_gelu = LinearGelu(module.dense)
+        del self.__dict__["_modules"]["dense"]
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.linear_gelu(hidden_states)
+        return hidden_states
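
For readers without ipex at hand: LinearGelu(module.dense) wraps the existing nn.Linear and computes the projection plus GELU activation in a single fused kernel, and the del self.__dict__["_modules"]["dense"] line then drops the now-redundant unfused submodule (its weights are carried by the fused module). A plain-PyTorch reference for what _IPEXIntermediate.forward computes; this is a sketch of the unfused math, not the ipex implementation:

    import torch
    from torch import nn

    class ReferenceIntermediate(nn.Module):
        """Unfused equivalent of _IPEXIntermediate: dense projection, then GELU."""

        def __init__(self, dense: nn.Linear):
            super().__init__()
            self.dense = dense

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # ipex's LinearGelu performs these two steps in one fused op;
            # the tests below accept up to 1e-4 absolute deviation from this.
            return nn.functional.gelu(self.dense(hidden_states))

    x = torch.randn(2, 4, 16)
    ref = ReferenceIntermediate(nn.Linear(16, 64))
    print(ref(x).shape)  # torch.Size([2, 4, 64])

Moving the fused-module imports to a single guarded top-level block also replaces the per-class version checks that each __init__ used to perform.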

optimum/intel/ipex/modeling_base.py

Lines changed: 23 additions & 16 deletions
@@ -51,7 +51,11 @@
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager

-from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _IPEX_MINIMUM_VERSION_FOR_PATCHING, _patch_model
+from ...exporters.ipex.model_patcher import (
+    _IPEX_EXPORTED_GENERATION_TASKS,
+    _IPEX_MINIMUM_VERSION_FOR_PATCHING,
+    _patch_model,
+)
 from ..generation.modeling import prepare_jit_inputs
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device
@@ -60,7 +64,7 @@
 logger = logging.getLogger(__name__)


-_IPEX_SUPPORT_MODEL_TYPES = ("llama",)
+_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit")
 _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")


@@ -70,17 +74,22 @@ def _is_patched_with_ipex(model, task):

     if isinstance(model, torch.jit.ScriptModule):
         for node in model.graph.nodes():
-            # Jit will record the codes position so we can check if the node use ipex exporter.
-            if "torch_ipex::rotary_position_embedding" in node.__str__():
+            # Only patched models enable fused linear ops.
+            if "/fusions/" in node.__str__():
                 return True
         return False
-    else:
+    elif task in _IPEX_EXPORTED_GENERATION_TASKS and model.config.hidden_size < 64:
         # The ipex IAKV op in patched model requires the hidden size at least 64
-        return (
-            model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES
-            and task in _IPEX_EXPORTED_TASK
-            and model.config.hidden_size >= 64
-        )
+        return False
+
+    return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES
+
+
+def _prepare_inputs_for_ipex_model(model, task, use_cache):
+    if task in _IPEX_EXPORTED_GENERATION_TASKS and _is_patched_with_ipex(model, task):
+        return get_dummy_input(model, return_dict=True)
+    else:
+        return prepare_jit_inputs(model, task, use_cache)


 def ipex_jit_trace(model, task, use_cache):
@@ -90,12 +99,8 @@ def ipex_jit_trace(model, task, use_cache):

     if _is_patched_with_ipex(model, task):
         model = _patch_model(model)
-        # TODO: integerate in prepare_jit_inputs.
-        sample_inputs = get_dummy_input(model, return_dict=True)
-        # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
-        _enable_tpp()
-    else:
-        sample_inputs = prepare_jit_inputs(model, task, use_cache)
+
+    sample_inputs = _prepare_inputs_for_ipex_model(model, task, use_cache)

     model.config.return_dict = False

@@ -104,6 +109,8 @@ def ipex_jit_trace(model, task, use_cache):
     if not use_cache:
         sample_inputs.pop("past_key_values")

+    # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
+    _enable_tpp()
     model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
     # Disable repack while jit tracing to reduce the memory
     ipex._C.disable_jit_linear_repack()
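
The reworked _is_patched_with_ipex relies on TorchScript recording the source location of the Python code that emitted each graph node: modules patched with ipex fusions produce nodes whose recorded paths contain "/fusions/", whereas the previous torch_ipex::rotary_position_embedding probe only fired for patched llama models. Non-generation models also no longer hit the hidden-size guard, which is only needed for the IAKV generation op. A minimal sketch of the detection idea, assuming a traced module is at hand:

    import torch

    def is_traced_model_patched(traced: torch.jit.ScriptModule) -> bool:
        # Each node's string form includes the source file of the code that
        # produced it; only ipex fused modules contribute "/fusions/" paths.
        for node in traced.graph.nodes():
            if "/fusions/" in str(node):
                return True
        return False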

tests/ipex/test_modeling.py

Lines changed: 30 additions & 0 deletions
@@ -188,6 +188,21 @@ def test_pipeline(self, model_arch):
         self.assertGreaterEqual(outputs["score"], 0.0)
         self.assertIsInstance(outputs["answer"], str)

+    @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version >= 2.3.0 supports ipex model patching")
+    def test_patched_model(self):
+        ipex_model = IPEXModelForQuestionAnswering.from_pretrained(
+            "Jiqing/patched_tiny_random_bert_for_question_answering"
+        )
+        transformers_model = AutoModelForQuestionAnswering.from_pretrained("hf-internal-testing/tiny-random-bert")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        inputs = "This is a sample input"
+        tokens = tokenizer(inputs, return_tensors="pt")
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**tokens)
+        outputs = ipex_model(**tokens)
+        self.assertTrue(torch.allclose(outputs.start_logits, transformers_outputs.start_logits, atol=1e-4))
+        self.assertTrue(torch.allclose(outputs.end_logits, transformers_outputs.end_logits, atol=1e-4))
+

 class IPEXModelForCausalLMTest(unittest.TestCase):
     IPEX_MODEL_CLASS = IPEXModelForCausalLM
@@ -458,3 +473,18 @@ def test_pipeline(self, model_arch):
         self.assertEqual(pipe.device, model.device)
         self.assertGreaterEqual(outputs[0]["score"], 0.0)
         self.assertTrue(isinstance(outputs[0]["label"], str))
+
+    @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version >= 2.3.0 supports ipex model patching")
+    def test_patched_model(self):
+        ipex_model = IPEXModelForImageClassification.from_pretrained(
+            "Jiqing/patched_tiny_random_vit_for_image_classification"
+        )
+        transformers_model = self.IPEX_MODEL_CLASS.from_pretrained("hf-internal-testing/tiny-random-vit")
+        preprocessor = AutoFeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-vit")
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+        inputs = preprocessor(images=image, return_tensors="pt")
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**inputs)
+        outputs = ipex_model(**inputs)
+        self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
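
Both new tests load a pre-patched, pre-traced checkpoint, compare its logits against the eager transformers baseline within atol=1e-4, and self-skip on ipex < 2.3.0. A hypothetical way to run just these cases from a checkout (assumes pytest is installed; -k is the standard pytest name filter):

    import pytest

    # Select only the two test_patched_model cases added in this commit.
    raise SystemExit(pytest.main(["tests/ipex/test_modeling.py", "-k", "test_patched_model", "-q"]))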
