Commit 6af6d6f

[shardformer] support bias_gelu_jit_fused for models (#5647)
* support gelu_bias_fused for gpt2
* support gelu_bias_fused for gpt2 fix fix fix
* fix fix
* fix
1 parent 7f8b166 commit 6af6d6f

File tree

8 files changed, +115 -2 lines

colossalai/shardformer/modeling/bert.py

Lines changed: 13 additions & 0 deletions
@@ -1287,3 +1287,16 @@ def forward(
         )

     return forward
+
+
+def get_jit_fused_bert_intermediate_forward():
+    from transformers.models.bert.modeling_bert import BertIntermediate
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.dense(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        return hidden_states
+
+    return forward

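Note: JitGeLUFunction above is colossalai.kernel.jit.bias_gelu.GeLUFunction, which fuses the bias addition into a tanh-approximated GeLU so the dense layer can hand its bias over unapplied. A minimal Megatron-LM-style sketch of such a fused autograd function is shown here; it illustrates the idea and is not guaranteed to match the exact ColossalAI kernel.

import torch

@torch.jit.script
def bias_gelu(bias, y):
    # Fused bias add + tanh-approximated GeLU in one TorchScript kernel.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

@torch.jit.script
def bias_gelu_back(g, bias, y):
    # Gradient of the fused op with respect to (bias + y).
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))
    ff = 0.5 * x * ((1.0 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1.0 + tanh_out)
    return ff * g

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        # The broadcasted bias gradient is reduced to the bias's shape by autograd.
        return tmp, tmp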
colossalai/shardformer/modeling/blip2.py

Lines changed: 14 additions & 0 deletions
@@ -129,3 +129,17 @@ def forward(
         return hidden_states

     return forward
+
+
+def get_jit_fused_blip2_mlp_forward():
+    from transformers.models.blip_2.modeling_blip_2 import Blip2MLP
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.fc1(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+    return forward

colossalai/shardformer/modeling/gpt2.py

Lines changed: 15 additions & 0 deletions
@@ -1310,3 +1310,18 @@ def forward(
         )

     return forward
+
+
+def get_jit_fused_gpt2_mlp_forward():
+    from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states, bias = self.c_fc(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+    return forward

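As a quick numerical sanity check of the fusion used in these forwards: adding the bias inside a tanh-approximated GeLU matches adding it first and then calling PyTorch's F.gelu(..., approximate="tanh"). The snippet below is self-contained; fused_bias_gelu is an illustrative stand-in, not an import from ColossalAI.

import torch
import torch.nn.functional as F

@torch.jit.script
def fused_bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Bias add fused into the tanh approximation of GeLU
    # (usual sqrt(2/pi) = 0.7978845608... formulation).
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))

x = torch.randn(2, 16, 3072, dtype=torch.float64)
bias = torch.randn(3072, dtype=torch.float64)

fused = fused_bias_gelu(bias, x)                   # one fused op
unfused = F.gelu(x + bias, approximate="tanh")     # separate bias add, then GeLU
print(torch.allclose(fused, unfused, atol=1e-6))   # expected: True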
colossalai/shardformer/modeling/vit.py

Lines changed: 12 additions & 0 deletions
@@ -372,3 +372,15 @@ def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Te
         return hidden_states

     return forward
+
+
+def get_jit_fused_vit_intermediate_forward():
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.dense(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+
+        return hidden_states
+
+    return forward

colossalai/shardformer/policies/bert.py

Lines changed: 12 additions & 0 deletions
@@ -12,6 +12,7 @@
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
     get_bert_flash_attention_forward,
+    get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
 )
@@ -38,11 +39,13 @@ def config_sanity_check(self):

     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
         return self.model

     def module_policy(self):
         from transformers.models.bert.modeling_bert import (
             BertEmbeddings,
+            BertIntermediate,
             BertLayer,
             BertModel,
             BertOutput,
@@ -131,6 +134,7 @@ def module_policy(self):
                     kwargs={
                         "seq_parallel_mode": sp_mode,
                         "overlap": overlap,
+                        "skip_bias_add": self.enable_bias_gelu_fused,
                     },
                 ),
                 SubModuleReplacementDescription(
@@ -153,6 +157,14 @@ def module_policy(self):
                 ),
             ]
         )
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_bert_intermediate_forward(),
+                },
+                policy=policy,
+                target_key=BertIntermediate,
+            )

         if sp_mode == "split_gather":
             self.append_or_create_method_replacement(

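The "skip_bias_add" kwarg passed to Linear1D_Col above is what makes the replaced dense layer return an (output, bias) pair, so the modeling-side forward can feed both into the fused kernel. A toy illustration of that convention follows; LinearSkipBiasAdd is a hypothetical stand-in, not the actual Linear1D_Col implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearSkipBiasAdd(nn.Module):
    """Toy linear layer illustrating the skip_bias_add convention."""

    def __init__(self, in_features: int, out_features: int, skip_bias_add: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.skip_bias_add = skip_bias_add
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)

    def forward(self, x: torch.Tensor):
        out = F.linear(x, self.weight)  # bias intentionally not added here
        if self.skip_bias_add:
            # Hand the bias back so the caller can fuse it into the next op
            # (here: the JIT bias+GeLU kernel).
            return out, self.bias
        return out + self.bias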
colossalai/shardformer/policies/blip2.py

Lines changed: 14 additions & 0 deletions
@@ -3,6 +3,7 @@
 from ..modeling.blip2 import (
     forward_fn,
     get_blip2_flash_attention_forward,
+    get_jit_fused_blip2_mlp_forward,
     get_jit_fused_blip2_QFormer_output_forward,
     get_jit_fused_blip2_QFormer_self_output_forward,
 )
@@ -18,12 +19,16 @@ def config_sanity_check(self):

     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.enable_bias_gelu_fused = (
+            self.shard_config.enable_jit_fused and self.model.config.vision_config.hidden_act == "gelu"
+        )
         return self.model

     def module_policy(self):
         from transformers.models.blip_2.modeling_blip_2 import (
             Blip2Attention,
             Blip2EncoderLayer,
+            Blip2MLP,
             Blip2QFormerLayer,
             Blip2QFormerModel,
             Blip2QFormerOutput,
@@ -73,6 +78,7 @@ def module_policy(self):
                 SubModuleReplacementDescription(
                     suffix="mlp.fc1",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.fc2",
@@ -201,6 +207,14 @@ def module_policy(self):
             )

         policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_blip2_mlp_forward(),
+                },
+                policy=policy,
+                target_key=Blip2MLP,
+            )

         if embedding_cls is not None:
             self.append_or_create_submodule_replacement(

colossalai/shardformer/policies/gpt2.py

Lines changed: 14 additions & 1 deletion
@@ -10,6 +10,7 @@
     GPT2PipelineForwards,
     get_gpt2_flash_attention_forward,
     get_gpt_model_forward_for_flash_attn,
+    get_jit_fused_gpt2_mlp_forward,
     get_lm_forward_with_dist_cross_entropy,
     gpt2_sequence_parallel_forward_fn,
 )
@@ -36,10 +37,13 @@ def preprocess(self):
         """
         self.tie_weight = self.tie_weight_check()
         self.origin_attn_implement = self.model.config._attn_implementation
+        self.enable_bias_gelu_fused = (
+            self.shard_config.enable_jit_fused and self.model.config.activation_function == "gelu"
+        )
         return self.model

     def module_policy(self):
-        from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+        from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model

         ATTN_IMPLEMENTATION = {
             "eager": GPT2Attention,
@@ -119,6 +123,7 @@ def module_policy(self):
                         "n_fused": 1,
                         "seq_parallel_mode": sp_mode,
                         "overlap": overlap,
+                        "skip_bias_add": self.enable_bias_gelu_fused,
                     },
                 ),
                 SubModuleReplacementDescription(
@@ -142,6 +147,14 @@ def module_policy(self):
                 ),
             ],
         )
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_gpt2_mlp_forward(),
+                },
+                policy=policy,
+                target_key=GPT2MLP,
+            )
         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
             self.append_or_create_submodule_replacement(

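Taken together, the fused path is gated on shard_config.enable_jit_fused plus the model reporting a plain "gelu" activation; stock GPT-2 checkpoints default to activation_function="gelu_new", which does not satisfy the check. A rough end-to-end usage sketch follows, assuming a ColossalAI distributed context has already been initialized (e.g. via colossalai.launch) and that the ShardConfig/ShardFormer arguments shown exist in your version.

# Rough sketch, not a verified recipe: enable the JIT bias+GeLU fusion via ShardConfig.
from transformers import GPT2LMHeadModel
from colossalai.shardformer import ShardConfig, ShardFormer

model = GPT2LMHeadModel.from_pretrained("gpt2")
# Illustrative only: stock GPT-2 uses "gelu_new"; forcing "gelu" makes the policy's
# check pass but slightly changes the activation's math.
model.config.activation_function = "gelu"

shard_config = ShardConfig(enable_tensor_parallelism=False, enable_jit_fused=True)
model, _ = ShardFormer(shard_config=shard_config).optimize(model)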
colossalai/shardformer/policies/vit.py

Lines changed: 21 additions & 1 deletion
@@ -11,6 +11,7 @@
     ViTForImageClassification_pipeline_forward,
     ViTForMaskedImageModeling_pipeline_forward,
     ViTModel_pipeline_forward,
+    get_jit_fused_vit_intermediate_forward,
     get_jit_fused_vit_output_forward,
     get_vit_flash_self_attention_forward,
 )
@@ -24,10 +25,17 @@ def config_sanity_check(self):
         pass

     def preprocess(self):
+        self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention
+        from transformers.models.vit.modeling_vit import (
+            ViTEmbeddings,
+            ViTIntermediate,
+            ViTLayer,
+            ViTOutput,
+            ViTSelfAttention,
+        )

         policy = {}

@@ -83,6 +91,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                 SubModuleReplacementDescription(
                     suffix="intermediate.dense",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "skip_bias_add": self.enable_bias_gelu_fused,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dense",
@@ -94,6 +105,14 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                 ),
             ],
         )
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_vit_intermediate_forward(),
+                },
+                policy=policy,
+                target_key=ViTIntermediate,
+            )

         # use flash attention
         if self.shard_config.enable_flash_attention:
@@ -115,6 +134,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                 policy=policy,
                 target_key=ViTOutput,
             )
+
         return policy

     def new_model_class(self):

0 commit comments