Skip to content

Commit 188c426

Browse files
Reverts 69ed0c5
PiperOrigin-RevId: 830610186
1 parent fdbd999 commit 188c426

File tree

5 files changed

+344
-414
lines changed

5 files changed

+344
-414
lines changed

src/MaxText/layers/decoders.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -395,15 +395,9 @@ def get_decoder_layers(self):
395395
return [mixtral.MixtralDecoderLayerToLinen]
396396
case DecoderBlockType.DEEPSEEK:
397397
if self.config.use_batch_split_schedule:
398-
return [
399-
deepseek_batchsplit.DeepSeekDenseLayerToLinen,
400-
deepseek_batchsplit.DeepSeekMoELayerToLinen,
401-
]
398+
return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
402399
else:
403-
return [
404-
deepseek.DeepSeekDenseLayerToLinen,
405-
deepseek.DeepSeekMoELayerToLinen,
406-
]
400+
return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
407401
case DecoderBlockType.GEMMA:
408402
return [gemma.GemmaDecoderLayerToLinen]
409403
case DecoderBlockType.GEMMA2:

0 commit comments

Comments
 (0)