|
2 | 2 |
|
3 | 3 | import torch |
4 | 4 | import torch.nn as nn |
5 | | -import transformers |
6 | | -from torch.fx import GraphModule |
7 | | -from transformers.models.gpt2.modeling_gpt2 import ( |
8 | | - GPT2MLP, |
9 | | - BaseModelOutputWithPastAndCrossAttentions, |
10 | | - GPT2PreTrainedModel, |
11 | | -) |
| 5 | +from transformers.activations import ACT2FN |
| 6 | +from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel |
12 | 7 | from transformers.pytorch_utils import Conv1D |
13 | 8 |
|
14 | | -from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP |
15 | | -from colossalai.auto_parallel.tensor_shard.solver import ( |
16 | | - CostGraph, |
17 | | - GraphAnalyser, |
18 | | - Solver, |
19 | | - SolverOptions, |
20 | | - StrategiesConstructor, |
21 | | -) |
22 | | -from colossalai.device.device_mesh import DeviceMesh |
23 | | -from colossalai.fx.tracer.tracer import ColoTracer |
24 | | -from colossalai.tensor.shape_consistency import ShapeConsistencyManager |
25 | | -from colossalai.testing import parameterize |
26 | | -from colossalai.testing.pytest_wrapper import run_on_environment_flag |
27 | | - |
28 | | -BATCH_SIZE = 1 |
29 | | -SEQ_LENGTH = 32 |
30 | | -HIDDEN_DIM = 768 |
| 9 | + |
class GPT2MLP(nn.Module):
    """Feed-forward (MLP) block of a GPT-2 layer.

    Projects the hidden states up to ``intermediate_size`` with a Conv1D,
    applies the configured activation, and projects back down to the
    embedding dimension. Dropout is intentionally omitted (see note below).

    Args:
        intermediate_size (int): width of the inner projection
            (typically ``4 * config.hidden_size``).
        config: a ``GPT2Config``-like object providing ``hidden_size`` and
            ``activation_function``.
    """

    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        # We temporarily banned the Dropout layer because the rng state needs
        # to be processed to get the correct result.
        # self.dropout = nn.Dropout(config.resid_pdrop)

    # NOTE: annotation changed from Optional[Tuple[torch.FloatTensor]] —
    # the typing imports were removed from this file, so those names would
    # raise NameError at class-definition time. The body always treats the
    # input as a single tensor, so torch.FloatTensor is also the honest type.
    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Apply c_fc -> activation -> c_proj to *hidden_states*."""
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        # TODO: the rng state needs to be fixed for distributed runtime
        # hidden_states = self.dropout(hidden_states)
        return hidden_states
31 | 29 |
|
32 | 30 |
|
33 | 31 | # The reason Why we don't import GPT2Attention from transformers directly is that: |
@@ -89,7 +87,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): |
89 | 87 |
|
90 | 88 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise |
91 | 89 | attn_weights = attn_weights.type(value.dtype) |
92 | | - attn_weights = self.attn_dropout(attn_weights) |
| 90 | + # attn_weights = self.attn_dropout(attn_weights) |
93 | 91 |
|
94 | 92 | # Mask heads if we want to |
95 | 93 | if head_mask is not None: |
@@ -125,15 +123,10 @@ def forward( |
125 | 123 | present = (key, value) |
126 | 124 |
|
127 | 125 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) |
128 | | - |
129 | 126 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) |
130 | 127 | attn_output = self.c_proj(attn_output) |
131 | | - attn_output = self.resid_dropout(attn_output) |
132 | | - |
133 | | - outputs = (attn_output, present) |
134 | | - outputs += (attn_weights,) |
135 | | - |
136 | | - return outputs # a, present, (attentions) |
| 128 | + # attn_output = self.resid_dropout(attn_output) |
| 129 | + return attn_output |
137 | 130 |
|
138 | 131 |
|
139 | 132 | class GPT2Block(nn.Module): |
@@ -161,19 +154,15 @@ def forward( |
161 | 154 | attention_mask=attention_mask, |
162 | 155 | head_mask=head_mask, |
163 | 156 | ) |
164 | | - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) |
165 | | - outputs = attn_outputs[1:] |
166 | 157 | # residual connection |
167 | | - hidden_states = attn_output + residual |
| 158 | + hidden_states = attn_outputs + residual |
168 | 159 | residual = hidden_states |
169 | 160 | hidden_states = self.ln_2(hidden_states) |
170 | 161 | feed_forward_hidden_states = self.mlp(hidden_states) |
171 | 162 | # residual connection |
172 | 163 | hidden_states = residual + feed_forward_hidden_states |
173 | 164 |
|
174 | | - outputs = (hidden_states,) + outputs[1:] |
175 | | - |
176 | | - return outputs # hidden_states, present, (attentions, cross_attentions) |
| 165 | + return hidden_states |
177 | 166 |
|
178 | 167 |
|
179 | 168 | class GPT2Model(GPT2PreTrainedModel): |
@@ -228,103 +217,25 @@ def forward( |
228 | 217 | # attention_probs has shape bsz x n_heads x N x N |
229 | 218 | # head_mask has shape n_layer x batch x n_heads x N x N |
230 | 219 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) |
231 | | - |
232 | 220 | inputs_embeds = self.wte(input_ids) |
233 | 221 | position_embeds = self.wpe(position_ids) |
| 222 | + |
234 | 223 | # add_2 |
235 | 224 | hidden_states = inputs_embeds + position_embeds |
236 | 225 |
|
237 | 226 | token_type_embeds = self.wte(token_type_ids) |
238 | 227 | hidden_states = hidden_states + token_type_embeds |
239 | 228 |
|
240 | | - # transformer_drop |
241 | | - hidden_states = self.drop(hidden_states) |
242 | 229 | # comment to run pipeline |
243 | 230 | # add_3 |
244 | 231 | output_shape = input_shape + (hidden_states.size(-1),) |
245 | 232 |
|
246 | | - presents = None |
247 | | - all_self_attentions = None |
248 | | - all_cross_attentions = None |
249 | | - all_hidden_states = None |
250 | 233 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): |
251 | 234 | outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i]) |
252 | | - hidden_states = outputs[0] |
| 235 | + hidden_states = outputs |
253 | 236 |
|
254 | 237 | hidden_states = self.ln_f(hidden_states) |
255 | 238 | # comment to run pipeline |
256 | 239 | hidden_states = hidden_states.view(output_shape) |
257 | 240 |
|
258 | | - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] |
259 | | - if v is not None) |
260 | | - |
261 | | - |
262 | | -@run_on_environment_flag(name='AUTO_PARALLEL') |
263 | | -@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) |
264 | | -def test_self_attention_block(model_cls): |
265 | | - config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM) |
266 | | - if model_cls == GPT2MLP: |
267 | | - model = model_cls(intermediate_size=4 * config.hidden_size, config=config) |
268 | | - else: |
269 | | - model = model_cls(config=config) |
270 | | - physical_mesh_id = torch.arange(0, 4) |
271 | | - mesh_shape = (2, 2) |
272 | | - # [[0, 1] |
273 | | - # [2, 3]] |
274 | | - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) |
275 | | - shape_consistency_manager = ShapeConsistencyManager() |
276 | | - |
277 | | - tracer = ColoTracer() |
278 | | - if model_cls == GPT2MLP: |
279 | | - input_sample = { |
280 | | - 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), |
281 | | - } |
282 | | - elif model_cls in (GPT2Attention, GPT2Block): |
283 | | - input_sample = { |
284 | | - 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), |
285 | | - 'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'), |
286 | | - } |
287 | | - else: |
288 | | - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) |
289 | | - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) |
290 | | - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) |
291 | | - kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) |
292 | | - input_sample = {k: v.to('meta') for k, v in kwargs.items()} |
293 | | - |
294 | | - graph = tracer.trace(root=model, meta_args=input_sample) |
295 | | - |
296 | | - gm = GraphModule(model, graph, model.__class__.__name__) |
297 | | - print(gm.graph) |
298 | | - gm.recompile() |
299 | | - graph_analyser = GraphAnalyser(gm) |
300 | | - liveness_list = graph_analyser.liveness_analysis() |
301 | | - solver_options = SolverOptions() |
302 | | - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) |
303 | | - strategies_constructor.build_strategies_and_cost() |
304 | | - |
305 | | - cost_graph = CostGraph(strategies_constructor.leaf_strategies) |
306 | | - cost_graph.simplify_graph() |
307 | | - solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1) |
308 | | - ret = solver.call_solver_serialized_args() |
309 | | - strategies_list = solver.last_s_val |
310 | | - nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] |
311 | | - |
312 | | - computation_cost = 0 |
313 | | - communication_cost = 0 |
314 | | - memory_cost = 0 |
315 | | - for index, node in enumerate(nodes): |
316 | | - print(node.name, node.strategies_vector[strategies_list[index]].name) |
317 | | - computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total |
318 | | - communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total |
319 | | - node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total |
320 | | - if isinstance(node_memory_cost, tuple): |
321 | | - node_memory_cost = node_memory_cost[0] |
322 | | - memory_cost += node_memory_cost.activation + node_memory_cost.parameter |
323 | | - |
324 | | - print(f'computation cost is {computation_cost}') |
325 | | - print(f'communication cost is {communication_cost}') |
326 | | - print(f'memory cost is {memory_cost}') |
327 | | - |
328 | | - |
329 | | -if __name__ == '__main__': |
330 | | - test_self_attention_block() |
| 241 | + return hidden_states |
0 commit comments