@@ -11,6 +11,7 @@
     ViTForImageClassification_pipeline_forward,
     ViTForMaskedImageModeling_pipeline_forward,
     ViTModel_pipeline_forward,
+    get_jit_fused_vit_intermediate_forward,
     get_jit_fused_vit_output_forward,
     get_vit_flash_self_attention_forward,
 )
@@ -24,10 +25,17 @@ def config_sanity_check(self):
         pass
 
     def preprocess(self):
+        self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention
+        from transformers.models.vit.modeling_vit import (
+            ViTEmbeddings,
+            ViTIntermediate,
+            ViTLayer,
+            ViTOutput,
+            ViTSelfAttention,
+        )
 
         policy = {}
 
@@ -83,6 +91,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     SubModuleReplacementDescription(
                         suffix="intermediate.dense",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "skip_bias_add": self.enable_bias_gelu_fused,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dense",
@@ -94,6 +105,14 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     ),
                 ],
             )
+            if self.enable_bias_gelu_fused:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_jit_fused_vit_intermediate_forward(),
+                    },
+                    policy=policy,
+                    target_key=ViTIntermediate,
+                )
 
         # use flash attention
         if self.shard_config.enable_flash_attention:
@@ -115,6 +134,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                 policy=policy,
                 target_key=ViTOutput,
             )
+
         return policy
 
     def new_model_class(self):
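For context on how the pieces above fit together: preprocess sets enable_bias_gelu_fused only when JIT fusion is requested and the model's activation is GeLU; skip_bias_add=True then makes the replaced intermediate.dense (col_nn.Linear1D_Col) return its bias separately instead of adding it; and the method replacement swaps in get_jit_fused_vit_intermediate_forward, which applies the bias and GeLU in one fused kernel. The diff does not show that forward's implementation; the following is a minimal sketch of the idea, assuming the sharded linear returns an (output, bias) pair under skip_bias_add and using an illustrative TorchScript kernel (bias_gelu_fused is a stand-in name, not necessarily the library's actual kernel):

import torch

@torch.jit.script
def bias_gelu_fused(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Fold the bias add into the tanh GeLU approximation so the activation
    # tensor is read and written once instead of twice.
    y = x + bias
    return y * 0.5 * (1.0 + torch.tanh(0.79788456 * y * (1.0 + 0.044715 * y * y)))

def get_jit_fused_vit_intermediate_forward():
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # With skip_bias_add=True the replaced dense layer returns the matmul
        # result and its bias separately, so they can be fused here.
        hidden_states, bias = self.dense(hidden_states)
        return bias_gelu_fused(hidden_states, bias)

    return forward

Skipping the separate elementwise bias pass matters because the intermediate activation is the largest tensor in the MLP (typically 4x the hidden size per token), which is also why the policy gates the fusion on hidden_act == "gelu" rather than enabling it unconditionally.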