Commit 6a5a01e

remove XPULinearXXX class definition for ipex (#1212)
* remove useless code

Signed-off-by: Liu, Kaixuan <[email protected]>

* fix xpu dtype tests error

Signed-off-by: Liu, Kaixuan <[email protected]>

---------

Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent a20051d commit 6a5a01e

2 files changed: +21, -127 lines


optimum/exporters/ipex/modeling_utils.py

Lines changed: 18 additions & 125 deletions
@@ -32,7 +32,7 @@
 
 logger = logging.getLogger(__name__)
 
-_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
+_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.6.0"
 
 _accelerate_added_attributes = ["to", "xpu"]
 
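Note: this version bump is the point of the commit. As the TODO removed below states, the fused XPULinearXXX ops ship inside IPEX itself from 2.6.0 onward, so the local wrapper classes become redundant. A minimal sketch of how a floor like this is typically enforced before patching (the helper name is hypothetical, not the file's actual code):

import intel_extension_for_pytorch as ipex  # assumed installed
from packaging import version

_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.6.0"

def _patching_supported() -> bool:  # hypothetical helper, for illustration
    # Compare the installed IPEX version against the minimum required for patching.
    return version.parse(ipex.__version__) >= version.parse(_IPEX_MINIMUM_VERSION_FOR_PATCHING)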
@@ -52,92 +52,6 @@
     )
 
 
-# TODO: Following XPULinearXXX op classes will be put into ipex after 2.6.0 version
-class XPULinear2SiluMul(torch.nn.Module):
-    def __init__(
-        self,
-        gate_proj: torch.nn.Module,
-        up_proj: torch.nn.Module,
-    ):
-        super().__init__()
-        self.gate_proj_weight = gate_proj.weight.transpose(0, 1).contiguous()
-        self.up_proj_weight = up_proj.weight.transpose(0, 1).contiguous()
-        self.gate_proj_bias = gate_proj.bias
-        self.up_proj_bias = up_proj.bias
-
-    def forward(
-        self,
-        hidden_states,
-    ):
-        up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight)
-        if self.gate_proj_bias is not None:
-            up += self.gate_proj_bias
-        hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up)
-        if self.up_proj_bias is not None:
-            hidden_states += self.up_proj_bias
-        return hidden_states
-
-
-class XPULinearGelu(torch.nn.Module):
-    def __init__(self, module: torch.nn.Module):
-        super().__init__()
-        self.weight = module.weight.transpose(0, 1).contiguous()
-        self.bias = module.bias
-
-    def forward(self, x):
-        return torch.ops.torch_ipex.matmul_gelu(x, self.weight, self.bias, 1.0, "tanh")
-
-
-class XPULinearAdd(torch.nn.Module):
-    def __init__(
-        self,
-        module: torch.nn.Module,
-    ):
-        super().__init__()
-        self.weight = module.weight.transpose(0, 1).contiguous()
-        self.bias = module.bias
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-    ):
-        token_len, _ = hidden_states.size()
-        if residual is None:
-            hidden_states = torch.matmul(hidden_states, self.weight)
-            if self.bias is not None:
-                hidden_states += self.bias
-        else:
-            if self.bias is not None:
-                hidden_states = torch.ops.torch_ipex.mm_bias_resadd(
-                    hidden_states, self.weight, self.bias, 1.0, residual, 1.0
-                )
-            else:
-                hidden_states = torch.addmm(
-                    residual.flatten(0, -2),
-                    hidden_states.flatten(0, -2),
-                    self.weight,
-                    beta=1.0,
-                )
-        hidden_states = hidden_states.view(token_len, -1)
-        return hidden_states
-
-
-class XPUlinearAddAdd(torch.nn.Module):
-    def __init__(self, module: torch.nn.Module):
-        super().__init__()
-        self.weight = module.weight.transpose(0, 1).contiguous()
-        self.bias = module.bias
-
-    def forward(self, x, y, z):
-        if self.bias is not None:
-            x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, self.bias, 1.0, y, 1.0)
-            x += z
-        else:
-            x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, z, 1.0, y, 1.0)
-        return x
-
-
 # Adapted from https://github.com/huggingface/accelerate/blob/v1.2.1/src/accelerate/hooks.py#L183
 def _remove_hooks_for_ipex(module, recurse):
     if hasattr(module, "_hf_hook"):
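For reference, the removed XPULinear2SiluMul fused the gated-MLP pattern into two IPEX XPU kernels (mm_silu and mm_resmul). A plain-PyTorch sketch of the same computation, mirroring the deleted module's bias handling (unfused, for illustration only):

import torch
import torch.nn.functional as F

def linear2_silu_mul_reference(x, gate_w, up_w, gate_b=None, up_b=None):
    # Unfused equivalent of the deleted XPULinear2SiluMul.forward.
    # gate_w / up_w are the (in_features, out_features) transposed weights
    # the module cached; biases are applied after each fused op, as above.
    up = F.silu(x @ gate_w)      # was torch.ops.torch_ipex.mm_silu
    if gate_b is not None:
        up = up + gate_b
    out = (x @ up_w) * up        # was torch.ops.torch_ipex.mm_resmul
    if up_b is not None:
        out = out + up_b
    return out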
@@ -885,11 +799,9 @@ def __init__(self, module, device, config) -> None:
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
+
         if not config.compile and module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-            if self.module_device.type == "cpu":
-                self.mha_linear_add = LinearAdd(module.o_proj)
-            elif self.module_device.type == "xpu":
-                self.mha_linear_add = XPULinearAdd(module.o_proj)
+            self.mha_linear_add = LinearAdd(module.o_proj)
 
     def qkv_gemm(self, hidden_states):
         if hasattr(self, "concat_qkv"):
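The deleted XPULinearAdd fused the attention output projection with its residual add; after this change LinearAdd covers both CPU and XPU. An unfused sketch of what that fusion computes, mirroring the deleted forward (illustration only):

import torch

def linear_add_reference(hidden_states, weight_t, bias=None, residual=None):
    # Unfused equivalent of the deleted XPULinearAdd.forward:
    # matmul against the cached transposed weight, optional bias,
    # then an optional residual add (fused via mm_bias_resadd / addmm above).
    out = hidden_states @ weight_t
    if bias is not None:
        out = out + bias
    if residual is not None:
        out = out + residual
    return out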
@@ -940,13 +852,8 @@ def __init__(self, module, device, config) -> None:
         self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
         self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
         self.c_proj_linear.bias = self.c_proj.bias
-        if self.module_device.type == "cpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = LinearAdd(self.c_proj_linear)
-
-        elif self.module_device.type == "xpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = XPULinearAdd(self.c_proj_linear)
+        if self.c_proj_linear not in ["LinearAllreduce"]:
+            self.linear_add = LinearAdd(self.c_proj_linear)
 
     def qkv_gemm(self, hidden_states):
         if hasattr(self, "c_attn_linear"):
@@ -977,17 +884,12 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
+
         if not config.compile and getattr(config, "quantization_config", None) is None:
-            if self.module_device.type == "cpu":
-                # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-                if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                    self.mlp_linear_add = LinearAdd(module.down_proj)
-                self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
-            elif self.module_device.type == "xpu":
-                # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-                if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                    self.mlp_linear_add = XPULinearAdd(module.down_proj)
-                self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj)
+            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
+            if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.mlp_linear_add = LinearAdd(module.down_proj)
+            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
 
     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
         if hasattr(self, "linear_silu_mul"):
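The context lines above show the dispatch convention these patched modules share: a fused op is attached in __init__ only when eligible (no torch.compile, no quantization config, no LinearAllreduce), and forward() probes for it with hasattr. A sketch of that pattern with an eager fallback (names are illustrative; the file's actual forward body is not shown in this diff):

import torch

class _PatchedMLPSketch(torch.nn.Module):
    # Illustrative only: fused path when the op was attached, eager fallback otherwise.
    def forward(self, hidden_states, residual=None, **kwargs):
        if hasattr(self, "linear_silu_mul"):
            hidden_states = self.linear_silu_mul(hidden_states)  # fused silu-mul
        else:
            hidden_states = self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
        if hasattr(self, "mlp_linear_add"):
            return self.mlp_linear_add(hidden_states, residual)  # fused linear + residual add
        hidden_states = self.down_proj(hidden_states)
        return hidden_states if residual is None else hidden_states + residual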
@@ -1012,15 +914,10 @@ def __init__(self, module, device, config) -> None:
         self.module_device = device
         if not config.compile and getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-            if self.module_device.type == "cpu":
-                self.linear_gelu = LinearGelu(module.dense_h_to_4h)
-            elif self.module_device.type == "xpu":
-                self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
+            self.linear_gelu = LinearGelu(module.dense_h_to_4h)
+
         if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
-            if self.module_device.type == "cpu":
-                self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
-            elif self.module_device.type == "xpu":
-                self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)
+            self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
 
     def forward(
         self,
@@ -1059,11 +956,9 @@ def __init__(self, module, device, config) -> None:
         self.c_proj_linear.bias = self.c_proj.bias
         if self.module_device.type == "cpu":
             self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = LinearAdd(self.c_proj_linear)
-        elif self.module_device.type == "xpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = XPULinearAdd(self.c_proj_linear)
+
+        if self.c_proj_linear not in ["LinearAllreduce"]:
+            self.linear_add = LinearAdd(self.c_proj_linear)
 
     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         if hasattr(self, "linear_new_gelu"):
@@ -1232,11 +1127,9 @@ def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.module_device = device
+
         if not config.compile and getattr(config, "quantization_config", None) is None:
-            if self.module_device.type == "cpu":
-                self.linear_gelu = LinearGelu(module.dense)
-            elif self.module_device.type == "xpu":
-                self.linear_gelu = XPULinearGelu(module.dense)
+            self.linear_gelu = LinearGelu(module.dense)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if hasattr(self, "linear_gelu"):

tests/ipex/test_modeling.py

Lines changed: 3 additions & 2 deletions
@@ -436,9 +436,10 @@ def test_compare_with_and_without_past_key_values(self):
     @parameterized.expand(IPEX_PATCHED_SUPPORTED_ARCHITECTURES)
     def test_patched_model(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
+        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
         patched_model_id = MODEL_NAMES["patched_" + model_arch]
-        ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, device_map=DEVICE)
-        exported_model = IPEXModelForCausalLM.from_pretrained(patched_model_id, device_map=DEVICE)
+        ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=dtype, device_map=DEVICE)
+        exported_model = IPEXModelForCausalLM.from_pretrained(patched_model_id, torch_dtype=dtype, device_map=DEVICE)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokens = tokenizer("This is a sample", return_tensors="pt").to(DEVICE)
         ipex_outputs = ipex_model.generate(
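The test fix selects the dtype by device: float16 on XPU, float32 on CPU. A usage sketch along the same lines (the checkpoint id is a placeholder, not one of the test's MODEL_NAMES entries):

import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

device = "xpu" if torch.xpu.is_available() else "cpu"
dtype = torch.float16 if device == "xpu" else torch.float32

model = IPEXModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder checkpoint
    export=True,
    torch_dtype=dtype,
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokens = tokenizer("This is a sample", return_tensors="pt").to(device)
outputs = model.generate(**tokens, max_new_tokens=16)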
