52 changes: 0 additions & 52 deletions examples/auto_parallel/models/modeling.py
@@ -482,7 +482,6 @@ def __init__(self, config, ipp: Optional[int] = None):
         )

         self.config = config
-        self.reshard_row_and_col = ReshardLayer()

     def forward(
         self,
@@ -529,7 +528,6 @@ def forward(
         )

         attn_output = self.o_proj(attn_output)
-        attn_output = self.reshard_row_and_col(attn_output)

         if not output_attentions:
             attn_weights = None
@@ -670,7 +668,6 @@ def __init__(self, config, layer_idx=0, ipp=0):
         self.residual_add2 = FusedDropoutAdd(
             config.hidden_dropout_prob, mode="upscale_in_train"
         )
-        self.reshard_col = ReshardLayer()
         self.reshard_replicate = ReshardLayer()

     def create_moe_mlp_layer(self, layer_idx, ipp):
@@ -816,7 +813,6 @@ def forward(
                 hidden_states, token_type_ids
             )
         else:
-            hidden_states = self.reshard_col(hidden_states)
             hidden_states = self.mlp(hidden_states)
             gate_logits = None

@@ -1003,7 +999,6 @@ def __init__(self, config: ErnieMoEConfig):
         self.inbatch_pack_offset = None
         self.token_type_ids = None
         self.past_key_values = None
-        self.inbatch_pack_offset = None
         self.inputs_embeds = None
         self.all_hidden_states = None
         self.all_self_attns = None
@@ -1287,7 +1282,6 @@ def forward(
         self.past_key_values = past_key_values
         self.inbatch_pack_offset = inbatch_pack_offset
         self.token_type_ids = token_type_ids
-        self.inbatch_pack_offset = inbatch_pack_offset
         if use_cache is not None:
             self.config.use_cache = use_cache
         if return_dict is not None:
@@ -1698,12 +1692,6 @@ def auto_dist_config(self, prefix=""):
            f"{prefix}ernie.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
            f"{prefix}ernie.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
            f"{prefix}ernie.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
-           f"{prefix}ernie.layers.*.self_attn.reshard_row_and_col": PrepareLayerInput(
-               layer_input_reshard_row_and_col_hook
-           ),
-           f"{prefix}ernie.layers.*.reshard_col": PrepareLayerInput(
-               layer_input_reshard_col_hook
-           ),
            f"{prefix}ernie.layers.*.reshard_replicate": PrepareLayerInput(
                layer_input_reshard_replicate_hook
            ),
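A note on the two markers kept above: `dist.ColWiseParallel()` shards a `Linear` weight over its output dimension on the model-parallel mesh axis, and `dist.RowWiseParallel()` shards it over its input dimension, the usual column/row split for attention projections. A rough, hypothetical equivalent with the lower-level placement API (the 2x2 mesh and layer width are assumptions, not values from this repo):

```python
import paddle
import paddle.distributed as dist

# Illustrative 2x2 mesh: axis 0 = data parallel ("dp"), axis 1 = model parallel ("mp").
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
linear = paddle.nn.Linear(1024, 1024)  # Paddle Linear weight has shape [in, out]

# Roughly what ColWiseParallel does: shard the weight's output (column) axis on "mp".
col_weight = dist.shard_tensor(linear.weight, mesh, [dist.Replicate(), dist.Shard(1)])

# Roughly what RowWiseParallel does: shard the weight's input (row) axis on "mp".
row_weight = dist.shard_tensor(linear.weight, mesh, [dist.Replicate(), dist.Shard(0)])
```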
@@ -1754,46 +1742,6 @@ def forward(self, input):
         return input


-def layer_input_reshard_row_and_col_hook(process_mesh):
-    def hook(layer, inputs, output=None):
-        res_inputs = []
-        for input in inputs:
-            if not input.is_dist():
-                x = dist.shard_tensor(
-                    input, process_mesh, [dist.Shard(0), dist.Shard(1)]
-                )
-                res_inputs.append(
-                    dist.reshard(x, process_mesh, [dist.Shard(0), dist.Shard(1)])
-                )
-            else:
-                res_inputs.append(
-                    dist.reshard(input, process_mesh, [dist.Shard(0), dist.Shard(1)])
-                )
-        return tuple(res_inputs)
-
-    return hook
-
-
-def layer_input_reshard_col_hook(process_mesh):
-    def hook(layer, inputs, output=None):
-        res_inputs = []
-        for input in inputs:
-            if not input.is_dist():
-                x = dist.shard_tensor(
-                    input, process_mesh, [dist.Shard(1), dist.Replicate()]
-                )
-                res_inputs.append(
-                    dist.reshard(x, process_mesh, [dist.Shard(1), dist.Replicate()])
-                )
-            else:
-                res_inputs.append(
-                    dist.reshard(input, process_mesh, [dist.Shard(1), dist.Replicate()])
-                )
-        return tuple(res_inputs)
-
-    return hook
-
-
 def layer_input_reshard_replicate_hook(process_mesh):
     def hook(layer, inputs, output=None):
         res_inputs = []
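The two deleted helpers above follow one pattern: turn each layer input into a dist tensor if needed, then force it onto fixed placements with `dist.reshard`. A self-contained sketch of that pattern for reference; the 2x2 mesh, tensor shape, and `reshard_inputs` name are illustrative assumptions, and it is meant to run under `paddle.distributed.launch` with four ranks:

```python
import paddle
import paddle.distributed as dist

# Illustrative 2x2 mesh, matching the two-level placements used by the deleted hooks.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

def reshard_inputs(inputs, placements, process_mesh=mesh):
    # Condensed form of the deleted layer_input_reshard_*_hook helpers:
    # promote plain tensors to dist tensors, then pin them to `placements`.
    res_inputs = []
    for t in inputs:
        if not t.is_dist():
            t = dist.shard_tensor(t, process_mesh, placements)
        res_inputs.append(dist.reshard(t, process_mesh, placements))
    return tuple(res_inputs)

x = paddle.randn([8, 1024])
# The removed "row and col" hook pinned inputs to [Shard(0), Shard(1)];
# the removed "col" hook used [Shard(1), Replicate()].
(y,) = reshard_inputs((x,), [dist.Shard(0), dist.Shard(1)])
```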
2 changes: 0 additions & 2 deletions examples/auto_parallel/models/modeling_vpp.py
@@ -29,7 +29,6 @@
     ErniePretrainedModel,
     ErnieModel,
     ErniePretrainingCriterion,
-    ReshardLayer,
     ErnieLMHead,
     ErnieDecoderLayer,
     ErnieAttention,
@@ -401,7 +400,6 @@ def __init__(self, config: ErnieMoEConfig, pp_layer_idx=None, ipp=0):
         self.all_self_attns = None
         self.next_decoder_cache = None
         self.inputs_embeds_cur_depth_list = None
-        self.reshard_replicate = ReshardLayer()

     def mtp_layer(
         self, hidden_states, inputs_embeds_cur_depth_list, attention_mask, position_ids