Skip to content

Commit a7727fb

Browse files
harshaljanjani authored and Rocketknight1 committed
docs: Add docstring notes for Switch Transformers sparse layer edge case
1 parent 8953826 commit a7727fb

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/transformers/models/switch_transformers/configuration_switch_transformers.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,10 +48,14 @@ class SwitchTransformersConfig(PreTrainedConfig):
4848
Number of dense hidden layers in the Transformer encoder layer.
4949
num_sparse_encoder_layers (`int`, *optional*, defaults to 3):
5050
Number of sparse (MoE) dense hidden layers in the Transformer encoder layer.
51+
Note: When set to 0 with `num_layers=1`, the current implementation may still create a sparse layer
52+
due to the sparse step calculation. This edge case is not encountered in existing checkpoints.
5153
num_decoder_layers (`int`, *optional*, defaults to 12):
5254
Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
5355
num_sparse_decoder_layers (`int`, *optional*, defaults to 3):
5456
Number of sparse (MoE) dense hidden layers in the Transformer decoder layer.
57+
Note: When set to 0 with `num_decoder_layers=1`, the current implementation may still create a sparse
58+
layer due to the sparse step calculation. This edge case is not encountered in existing checkpoints.
5559
num_heads (`int`, *optional*, defaults to 12):
5660
Number of attention heads for each attention layer in the Transformer encoder.
5761
num_experts (`int`, *optional*, defaults to 8):
@@ -148,13 +152,13 @@ def __init__(
148152
if self.num_sparse_encoder_layers > 0:
149153
self.encoder_sparse_step = self.num_layers // self.num_sparse_encoder_layers
150154
else:
151-
self.encoder_sparse_step = 0 # 0 means no sparse layers (modeling code checks sparse_step > 0)
155+
self.encoder_sparse_step = self.num_layers # HACK: this will create 0 sparse layers
152156

153157
# This tells us, each how many decoder layer we'll have to set a sparse layer.
154158
if self.num_sparse_decoder_layers > 0:
155159
self.decoder_sparse_step = self.num_decoder_layers // self.num_sparse_decoder_layers
156160
else:
157-
self.decoder_sparse_step = 0 # 0 means no sparse layers (modeling code checks sparse_step > 0)
161+
self.decoder_sparse_step = self.num_decoder_layers # HACK: this will create 0 sparse layers
158162

159163
self.num_heads = num_heads
160164
self.num_experts = num_experts

0 commit comments

Comments
 (0)