Skip to content

Commit 629a2e3

Browse files
authored
Disable tpp for un-verified models (#822)
1 parent 07ada34 commit 629a2e3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

optimum/intel/ipex/modeling_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def ipex_jit_trace(model, task, use_cache):
110110
sample_inputs.pop("past_key_values")
111111

112112
# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
113-
# Only ipex >= 2.3.0 supports tpp.
114-
if is_ipex_version(">=", "2.3.0"):
113+
# Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks.
114+
if is_ipex_version(">=", "2.3.0") and task in _IPEX_EXPORTED_GENERATION_TASKS:
115115
_enable_tpp()
116116
model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
117117
# Disable repack while jit tracing to reduce the memory

0 commit comments

Comments
 (0)