
Commit 550f8f8

[autoparallel] integrate_gpt_related_tests (#2134)
* [autoparallel] integrate_gpt_related_tests
* polish code
* polish code
* add GPT2Model into runtime test
1 parent 59e3433 commit 550f8f8

File tree

5 files changed: +221 −211 lines changed


colossalai/auto_parallel/passes/runtime_preparation_pass.py

Lines changed: 12 additions & 2 deletions
@@ -230,7 +230,12 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             new_slice_items = []
 
             for slice_item in getitem_index:
+                if slice_item is None:
+                    new_slice_items.append(None)
+                    continue
+
                 new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
+
                 if slice_item.start in node_pairs:
                     new_start = node_pairs[slice_item.start]
                 elif slice_item.stop in node_pairs:
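For context, a minimal sketch (not the pass itself) of why a getitem index can contain None: indexing a tensor with None inserts a new axis, so the traced index tuple mixes slice objects with None entries that have no start/stop/step to remap, hence the early continue added above.

import torch

# Indexing with None (unsqueeze) puts a None entry into the getitem index tuple.
x = torch.rand(4, 8)
index = (slice(0, 2), None)    # equivalent to x[0:2, None]
print(x[index].shape)          # torch.Size([2, 1, 8])

# None has no .start/.stop/.step, so it must be skipped before remapping.
for item in index:
    if item is None:
        continue
    print(item.start, item.stop, item.step)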
@@ -355,7 +360,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     for node in nodes:
         if node.op == 'call_module':
             target_module = node.graph.owning_module.get_submodule(node.target)
-
+            # TODO: we need to do more actions to take care of the shared parameters.
+            if hasattr(target_module, 'processed') and target_module.processed:
+                continue
+            setattr(target_module, 'processed', True)
             for name, param in target_module.named_parameters():
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                 # apply the sharding spec of parameters
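A minimal sketch of the guard this hunk introduces, using a hypothetical shard_params_once helper: when two call_module nodes resolve to the same submodule (shared or tied weights), its parameters should only be handled on the first visit.

import torch.nn as nn

shared = nn.Linear(8, 8)    # one module reused at two call sites

def shard_params_once(module):
    # mirrors the `processed` flag set in _module_params_sharding
    if getattr(module, 'processed', False):
        return
    setattr(module, 'processed', True)
    for name, param in module.named_parameters():
        print(f'sharding {name} with shape {tuple(param.shape)}')

shard_params_once(shared)    # shards weight and bias
shard_params_once(shared)    # second visit is a no-op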
@@ -404,7 +412,9 @@ def hook_fn(grad):
                 target_module = root
                 target = getattr(root, atoms[0])
             else:
-                target_module = root.get_submodule(atoms[-2])
+                target_module = root
+                for atom in atoms[:-1]:
+                    target_module = getattr(target_module, atom)
                 target = getattr(target_module, atoms[-1])
 
             target_sharding_spec = node.sharding_spec
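A minimal sketch of the lookup this hunk fixes: for a parameter nested several modules deep, its owning module is reached by walking every prefix atom of the qualified name, whereas the old single-atom lookup only worked one level below the root module.

import torch.nn as nn

root = nn.Sequential(nn.Sequential(nn.Linear(4, 4)))
atoms = '0.0.weight'.split('.')    # qualified parameter name

target_module = root
for atom in atoms[:-1]:
    target_module = getattr(target_module, atom)
target = getattr(target_module, atoms[-1])
print(type(target_module).__name__, tuple(target.shape))    # Linear (4, 4)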

tests/test_auto_parallel/test_tensor_shard/test_gpt/__init__.py

Whitespace-only changes.

tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py renamed to tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py

Lines changed: 30 additions & 119 deletions
@@ -2,32 +2,30 @@
 
 import torch
 import torch.nn as nn
-import transformers
-from torch.fx import GraphModule
-from transformers.models.gpt2.modeling_gpt2 import (
-    GPT2MLP,
-    BaseModelOutputWithPastAndCrossAttentions,
-    GPT2PreTrainedModel,
-)
+from transformers.activations import ACT2FN
+from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel
 from transformers.pytorch_utils import Conv1D
 
-from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.testing import parameterize
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-
-BATCH_SIZE = 1
-SEQ_LENGTH = 32
-HIDDEN_DIM = 768
+
+class GPT2MLP(nn.Module):
+
+    def __init__(self, intermediate_size, config):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
+        self.act = ACT2FN[config.activation_function]
+        # We temporarily banned the Dropout layer because the rng state need
+        # to process to get the correct result.
+        # self.dropout = nn.Dropout(config.resid_pdrop)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        # TODO: the rng state need to be fixed for distributed runtime
+        # hidden_states = self.dropout(hidden_states)
+        return hidden_states
 
 
 # The reason Why we don't import GPT2Attention from transformers directly is that:
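A hedged usage sketch of the trimmed GPT2MLP defined above (assuming the typing names Optional and Tuple it annotates with are importable in the module); the shapes mirror the BATCH_SIZE/SEQ_LENGTH/HIDDEN_DIM constants used by the tests.

import torch
import transformers

config = transformers.GPT2Config(n_embd=768)
mlp = GPT2MLP(intermediate_size=4 * config.hidden_size, config=config)

hidden_states = torch.rand(1, 32, 768)    # (BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM)
print(mlp(hidden_states).shape)           # torch.Size([1, 32, 768])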
@@ -89,7 +87,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
 
         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
         attn_weights = attn_weights.type(value.dtype)
-        attn_weights = self.attn_dropout(attn_weights)
+        # attn_weights = self.attn_dropout(attn_weights)
 
         # Mask heads if we want to
         if head_mask is not None:
@@ -125,15 +123,10 @@ def forward(
         present = (key, value)
 
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
-
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        outputs += (attn_weights,)
-
-        return outputs  # a, present, (attentions)
+        # attn_output = self.resid_dropout(attn_output)
+        return attn_output
 
 
 class GPT2Block(nn.Module):
@@ -161,19 +154,15 @@ def forward(
             attention_mask=attention_mask,
             head_mask=head_mask,
         )
-        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-        outputs = attn_outputs[1:]
         # residual connection
-        hidden_states = attn_output + residual
+        hidden_states = attn_outputs + residual
         residual = hidden_states
         hidden_states = self.ln_2(hidden_states)
         feed_forward_hidden_states = self.mlp(hidden_states)
         # residual connection
         hidden_states = residual + feed_forward_hidden_states
 
-        outputs = (hidden_states,) + outputs[1:]
-
-        return outputs  # hidden_states, present, (attentions, cross_attentions)
+        return hidden_states
 
 
 class GPT2Model(GPT2PreTrainedModel):
@@ -228,103 +217,25 @@ def forward(
         # attention_probs has shape bsz x n_heads x N x N
         # head_mask has shape n_layer x batch x n_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
+
         # add_2
         hidden_states = inputs_embeds + position_embeds
 
         token_type_embeds = self.wte(token_type_ids)
         hidden_states = hidden_states + token_type_embeds
 
-        # transformer_drop
-        hidden_states = self.drop(hidden_states)
         # comment to run pipeline
         # add_3
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        presents = None
-        all_self_attentions = None
-        all_cross_attentions = None
-        all_hidden_states = None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
-            hidden_states = outputs[0]
+            hidden_states = outputs
 
         hidden_states = self.ln_f(hidden_states)
         # comment to run pipeline
         hidden_states = hidden_states.view(output_shape)
 
-        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
-                     if v is not None)
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
-def test_self_attention_block(model_cls):
-    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
-    if model_cls == GPT2MLP:
-        model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
-    else:
-        model = model_cls(config=config)
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    # [[0, 1]
-    #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    shape_consistency_manager = ShapeConsistencyManager()
-
-    tracer = ColoTracer()
-    if model_cls == GPT2MLP:
-        input_sample = {
-            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-        }
-    elif model_cls in (GPT2Attention, GPT2Block):
-        input_sample = {
-            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-            'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
-        }
-    else:
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
-        input_sample = {k: v.to('meta') for k, v in kwargs.items()}
-
-    graph = tracer.trace(root=model, meta_args=input_sample)
-
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    print(gm.graph)
-    gm.recompile()
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
-    ret = solver.call_solver_serialized_args()
-    strategies_list = solver.last_s_val
-    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
-
-    computation_cost = 0
-    communication_cost = 0
-    memory_cost = 0
-    for index, node in enumerate(nodes):
-        print(node.name, node.strategies_vector[strategies_list[index]].name)
-        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
-        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
-        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
-        if isinstance(node_memory_cost, tuple):
-            node_memory_cost = node_memory_cost[0]
-        memory_cost += node_memory_cost.activation + node_memory_cost.parameter
-
-    print(f'computation cost is {computation_cost}')
-    print(f'communication cost is {communication_cost}')
-    print(f'memory cost is {memory_cost}')
-
-
-if __name__ == '__main__':
-    test_self_attention_block()
+        return hidden_states
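For reference, a condensed sketch adapted from the deleted test code shown above: the trimmed GPT2Model is traced with meta inputs by ColoTracer so the auto-parallel solver can analyse the graph without real data. The class and input shapes are taken from the removed test; the surrounding solver calls are omitted here.

import torch
import transformers
from torch.fx import GraphModule
from colossalai.fx.tracer.tracer import ColoTracer

BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM = 1, 32, 768
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
model = GPT2Model(config=config)    # the trimmed GPT2Model defined in gpt_modules.py

kwargs = dict(input_ids=torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64),
              token_type_ids=torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64),
              attention_mask=torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64))
meta_args = {k: v.to('meta') for k, v in kwargs.items()}

tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
print(gm.graph)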
