@@ -48,10 +48,14 @@ class SwitchTransformersConfig(PreTrainedConfig):
4848 Number of dense hidden layers in the Transformer encoder layer.
4949 num_sparse_encoder_layers (`int`, *optional*, defaults to 3):
5050 Number of sparse (MoE) dense hidden layers in the Transformer encoder layer.
51+ Note: When set to 0 with `num_layers=1`, the current implementation may still create a sparse layer
52+ due to the sparse step calculation. This edge case is not encountered in existing checkpoints.
5153 num_decoder_layers (`int`, *optional*, defaults to 12):
5254 Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
5355 num_sparse_decoder_layers (`int`, *optional*, defaults to 3):
5456 Number of sparse (MoE) dense hidden layers in the Transformer decoder layer.
57+ Note: When set to 0 with `num_decoder_layers=1`, the current implementation may still create a sparse
58+ layer due to the sparse step calculation. This edge case is not encountered in existing checkpoints.
5559 num_heads (`int`, *optional*, defaults to 12):
5660 Number of attention heads for each attention layer in the Transformer encoder.
5761 num_experts (`int`, *optional*, defaults to 8):
@@ -148,13 +152,13 @@ def __init__(
148152 if self .num_sparse_encoder_layers > 0 :
149153 self .encoder_sparse_step = self .num_layers // self .num_sparse_encoder_layers
150154 else :
151- self .encoder_sparse_step = 0 # 0 means no sparse layers (modeling code checks sparse_step > 0)
155+ self .encoder_sparse_step = self . num_layers # HACK: this will create 0 sparse layers
152156
153157 # This gives the interval, in decoder layers, at which a sparse layer is placed.
154158 if self .num_sparse_decoder_layers > 0 :
155159 self .decoder_sparse_step = self .num_decoder_layers // self .num_sparse_decoder_layers
156160 else :
157- self .decoder_sparse_step = 0 # 0 means no sparse layers (modeling code checks sparse_step > 0)
161+ self .decoder_sparse_step = self . num_decoder_layers # HACK: this will create 0 sparse layers
158162
159163 self .num_heads = num_heads
160164 self .num_experts = num_experts
0 commit comments