
Commit 55853e6

fix circular import issue (#3445)
1 parent 18eeefa commit 55853e6

File tree

3 files changed: +31 −36 lines


intel_extension_for_pytorch/cpu/tpp/fused_bert.py

Lines changed: 15 additions & 21 deletions
@@ -13,8 +13,19 @@
 from ...utils._logger import logger, WarningType

 try:
-    from transformers.modeling_utils import apply_chunking_to_forward
-    from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+    import transformers
+
+    transformers_orig_is_tensor = transformers.file_utils.is_tensor
+
+    def is_tensor(x):
+        """Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`."""
+        if transformers_orig_is_tensor(x):
+            return True
+        if isinstance(x, BlockedTensor):
+            return True
+        return False
+
+    transformers.file_utils.is_tensor = is_tensor
 except ImportError:
     pass
 USE_BF16_PARAMS = True
USE_BF16_PARAMS = True
@@ -976,7 +987,7 @@ def forward(
             cross_attn_present_key_value = cross_attention_outputs[-1]
             present_key_value = present_key_value + cross_attn_present_key_value

-        layer_output = apply_chunking_to_forward(
+        layer_output = transformers.modeling_utils.apply_chunking_to_forward(
             self.feed_forward_chunk,
             self.chunk_size_feed_forward,
             self.seq_len_dim,
@@ -1109,7 +1120,7 @@ def custom_forward(*inputs):
             ]
             if v is not None
         )
-        return BaseModelOutputWithPastAndCrossAttentions(
+        return transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             past_key_values=next_decoder_cache,
             hidden_states=all_hidden_states,
@@ -1178,23 +1189,6 @@ def forward(self, hidden_states):
 # return bm_default_blocking_factors
 # BlockedModule.default_blocking_factors = custom_blocking_factors

-try:
-    import transformers
-
-    transformers_orig_is_tensor = transformers.file_utils.is_tensor
-
-    def is_tensor(x):
-        """Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`."""
-        if transformers_orig_is_tensor(x):
-            return True
-        if isinstance(x, BlockedTensor):
-            return True
-        return False
-
-    transformers.file_utils.is_tensor = is_tensor
-except ImportError:
-    pass
-

 def block(model):
     for m in model.modules():
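
The fix pattern in this file is worth spelling out: a `from transformers.modeling_utils import apply_chunking_to_forward` binds the name while `transformers` may still be mid-initialization, whereas a bare `import transformers` followed by fully-qualified access defers the attribute lookup to call time, after every module in the cycle has finished loading. A minimal sketch of the idea with two hypothetical modules (`a.py` and `b.py` are illustrative, not part of this repository):

# a.py (hypothetical)
import b  # binds only the module object; b's attributes are resolved later


def helper():
    return "a.helper"


def run():
    # Attribute lookup happens at call time, after both modules have
    # finished executing, so the a <-> b cycle is harmless.
    return b.work()


# b.py (hypothetical)
import a  # gets the partially initialized module a from sys.modules


def work():
    # "from a import helper" at the top of b.py would raise ImportError
    # while a.py is still mid-import; "a.helper" resolved here cannot.
    return a.helper()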

intel_extension_for_pytorch/llm/__init__.py

Lines changed: 9 additions & 6 deletions
@@ -29,12 +29,15 @@
     transformers.dynamic_module_utils.get_class_from_dynamic_module = (
         _get_class_from_dynamic_module
     )
-    transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_disable = (
-        _gradient_checkpointing_disable
-    )
-    transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
-        _gradient_checkpointing_enable
-    )
+    from packaging import version
+
+    if version.parse(transformers.__version__) < version.parse("4.36"):
+        transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_disable = (
+            _gradient_checkpointing_disable
+        )
+        transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
+            _gradient_checkpointing_enable
+        )
     transformers.tokenization_utils_base.PreTrainedTokenizerBase.pad = _pad
 except ImportError:
     pass
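
The new guard compares release versions with `packaging.version.parse` rather than raw strings, which matters because lexicographic comparison mis-orders versions such as "4.9" and "4.36". A standalone sketch of the guard pattern (patch body elided; `_my_patch` is a hypothetical placeholder, not the commit's code):

import transformers
from packaging import version

# "4.9" > "4.36" lexicographically, but 4.9 < 4.36 as release versions;
# packaging.version compares each dotted component numerically.
if version.parse(transformers.__version__) < version.parse("4.36"):

    def _my_patch(self):  # hypothetical stand-in for the real patch
        pass

    transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_disable = (
        _my_patch
    )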

intel_extension_for_pytorch/transformers/models/reference/models.py

Lines changed: 7 additions & 9 deletions
@@ -23,10 +23,6 @@
     _prepare_4d_causal_attention_mask,
 )

-if hasattr(transformers.models, "mixtral"):
-    from transformers.models.mixtral.modeling_mixtral import (
-        load_balancing_loss_func,
-    )
 from transformers.modeling_outputs import (
     MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
@@ -3277,10 +3273,12 @@ def MixtralForCausalLM_forward(

     aux_loss = None
     if output_router_logits:
-        aux_loss = load_balancing_loss_func(
-            outputs.router_logits if return_dict else outputs[-1],
-            self.num_experts,
-            self.num_experts_per_tok,
+        aux_loss = (
+            transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func(
+                outputs.router_logits if return_dict else outputs[-1],
+                self.num_experts,
+                self.num_experts_per_tok,
+            )
         )
         if labels is not None:
             loss += self.router_aux_loss_coef * aux_loss
@@ -5828,7 +5826,7 @@ def JambaForCausalLM_forward(

     aux_loss = None
     if output_router_logits:
-        aux_loss = load_balancing_loss_func(
+        aux_loss = transformers.models.jamba.modeling_jamba.load_balancing_loss_func(
             outputs.router_logits if return_dict else outputs[-1],
             self.num_experts,
             self.num_experts_per_tok,
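
Resolving `load_balancing_loss_func` through its full dotted path at call time, instead of importing it at module top level behind a `hasattr(transformers.models, "mixtral")` check, keeps this module importable even when the Mixtral code would itself participate in the import cycle. A sketch of the same lazy-resolution idea wrapped in a helper (`_mixtral_aux_loss` is hypothetical, not part of the commit):

import transformers


def _mixtral_aux_loss(router_logits, num_experts, top_k):
    # Hypothetical helper: nothing under transformers.models.mixtral is
    # touched until this function is actually called.
    mixtral = getattr(transformers.models, "mixtral", None)
    if mixtral is None:
        return None  # assumption: callers treat None as "no aux loss"
    return mixtral.modeling_mixtral.load_balancing_loss_func(
        router_logits, num_experts, top_k
    )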
