
Commit ca56b93

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: 108ddfb


2 files changed: +11 −13 lines


colossalai/shardformer/modeling/opt.py

Lines changed: 6 additions & 4 deletions
@@ -21,7 +21,9 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig
+
 from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)


@@ -351,7 +353,7 @@ def opt_for_causal_lm_forward(
 loss_fct = CrossEntropyLoss()
 shift_logits = shift_logits.view(-1, self.config.vocab_size)
 loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
-
+
 if not return_dict:
 output = (logits,) + outputs[1:]
 return (loss,) + output if loss is not None else output
@@ -987,8 +989,8 @@ def forward(
 process_group=shard_config.tensor_parallel_process_group,
 vocab_size=self.lm_head.out_features,
 )
-#loss_fct = CrossEntropyLoss()
-#loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+# loss_fct = CrossEntropyLoss()
+# loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

 if not return_dict:
 output = (logits,) + outputs[1:]
@@ -1002,4 +1004,4 @@ def forward(
 attentions=outputs.attentions,
 )

-return forward
+return forward
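
Note: the hunks above keep the distributed cross-entropy path (cross_entropy_1d) active and merely reformat the commented-out single-device reference loss. For comparison only, below is a minimal, self-contained sketch of what that commented-out reference path computes, i.e. shifted next-token cross-entropy in plain PyTorch. The tensor shapes and the vocab_size value are illustrative assumptions, not values taken from this commit.

# Sketch of the commented-out reference loss path (plain PyTorch).
# Shapes and vocab_size are example values chosen for illustration.
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 50272                        # example value (OPT-style vocabulary)
logits = torch.randn(2, 8, vocab_size)    # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 8))

# Next-token prediction: logits at position t are scored against labels at t+1.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())

The distributed variant in the file replaces this with cross_entropy_1d, passing the tensor-parallel process group and the sharded lm_head's out_features as the vocabulary size, as visible in the third hunk.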

colossalai/shardformer/policies/opt.py

Lines changed: 5 additions & 9 deletions
@@ -21,9 +21,9 @@
 from ..modeling.opt import (
 OPTPipelineForwards,
 get_jit_fused_opt_decoder_layer_forward,
+get_lm_forward_with_dist_cross_entropy,
 get_opt_decoder_forward_for_flash_attention,
 get_opt_flash_attention_forward,
-get_lm_forward_with_dist_cross_entropy
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -270,21 +270,17 @@ def module_policy(self):
 suffix="lm_head",
 target_module=VocabParallelLMHead1D,
 kwargs=dict(
-gather_output=not self.shard_config.parallel_output,
-make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+gather_output=not self.shard_config.parallel_output,
+make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
 ),
 ),
 policy=policy,
 target_key=OPTForCausalLM,
 )
 if self.shard_config.parallel_output:
-method_replacement = {
-"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
-}
+method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
 self.append_or_create_method_replacement(
-description=method_replacement,
-policy=policy,
-target_key=OPTForCausalLM
+description=method_replacement, policy=policy, target_key=OPTForCausalLM
 )
 else:
 self.append_or_create_submodule_replacement(
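
Note: the reformatted branch above is purely cosmetic; behavior is unchanged. When shard_config.parallel_output is set, the policy registers a method replacement that swaps OPTForCausalLM.forward for get_lm_forward_with_dist_cross_entropy(self.shard_config). As a rough, hypothetical illustration of what a forward swap of that kind amounts to in plain Python (not ColossalAI's actual mechanism; the names make_parallel_forward and DummyModel are invented for this sketch):

# Hypothetical sketch: rebinding a model's forward at runtime.
# make_parallel_forward and DummyModel are invented names for illustration only.
import types

class DummyModel:
    def forward(self, x):
        return f"standard forward({x})"

def make_parallel_forward(shard_config):
    # Analogous in spirit to get_lm_forward_with_dist_cross_entropy(shard_config):
    # returns a replacement forward that closes over the shard configuration.
    def forward(self, x):
        return f"parallel forward({x}) with config={shard_config}"
    return forward

model = DummyModel()
parallel_output = True  # stands in for shard_config.parallel_output
if parallel_output:
    # Bind the replacement so it receives `self` like a normal method.
    model.forward = types.MethodType(make_parallel_forward({"tp_size": 2}), model)

print(model.forward("tokens"))  # -> parallel forward(tokens) with config={'tp_size': 2}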

0 commit comments