@@ -54,7 +54,6 @@ def module_policy(self):
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
-        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
 
         overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
@@ -78,40 +77,34 @@ def module_policy(self):
                         suffix="attn.k_proj",
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
-                            "seq_parallel": use_sequence_parallel,
                             "overlap": overlap,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.q_proj",
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
-                            "seq_parallel": use_sequence_parallel,
                             "overlap": overlap,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.v_proj",
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
-                            "seq_parallel": use_sequence_parallel,
                             "overlap": overlap,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.out_proj",
                         target_module=col_nn.Linear1D_Row,
-                        kwargs={"seq_parallel": use_sequence_parallel},
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.fc_in",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel},
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.fc_out",
                         target_module=col_nn.Linear1D_Row,
-                        kwargs={"seq_parallel": use_sequence_parallel},
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.attn_dropout",