@@ -38,7 +38,7 @@ def test_starcoder2_model_provider_defaults(self):
3838 assert provider .hidden_size == 768
3939 assert provider .num_attention_heads == 12
4040
41- # Check Starcoder2-specific defaults
41+ # Check Starcoder2-specific defaults + transformer config post init
4242 assert provider .normalization == "LayerNorm"
4343 assert provider .activation_func == F .gelu
4444 assert provider .add_bias_linear is True
@@ -49,28 +49,14 @@ def test_starcoder2_model_provider_defaults(self):
4949 assert provider .attention_dropout == 0.0
5050 assert provider .init_method_std == 0.01
5151 assert provider .share_embeddings_and_output_weights is False
52- assert provider .kv_channels is None
53- assert provider .num_query_groups is None
52+ assert provider .kv_channels == 64
53+ assert provider .num_query_groups == 12
5454 assert provider .window_size is None
5555 assert provider .attention_softmax_in_fp32 is True
5656 assert provider .bias_activation_fusion is True
5757 assert provider .bias_dropout_fusion is True
5858 assert provider .layernorm_epsilon == 1e-5
5959
60- def test_starcoder2_model_provider_inheritance (self ):
61- """Test Starcoder2ModelProvider inherits from GPTModelProvider."""
62- from megatron .bridge .models .gpt_provider import GPTModelProvider
63-
64- provider = Starcoder2ModelProvider (
65- num_layers = 12 ,
66- hidden_size = 768 ,
67- num_attention_heads = 12 ,
68- )
69-
70- assert isinstance (provider , GPTModelProvider )
71- assert hasattr (provider , "provide" )
72- assert callable (provider .provide )
73-
7460
7561class TestStarcoder2ModelProvider3B :
7662 """Test cases for Starcoder2ModelProvider3B class."""
@@ -98,21 +84,12 @@ def test_starcoder2_3b_defaults(self):
9884 assert provider .hidden_dropout == 0.0
9985 assert provider .attention_dropout == 0.0
10086 assert provider .share_embeddings_and_output_weights is False
101- assert provider .kv_channels is None
10287 assert provider .window_size is None
10388 assert provider .attention_softmax_in_fp32 is True
10489 assert provider .bias_activation_fusion is True
10590 assert provider .bias_dropout_fusion is True
10691 assert provider .layernorm_epsilon == 1e-5
10792
108- def test_starcoder2_3b_inheritance (self ):
109- """Test Starcoder2ModelProvider3B inherits from Starcoder2ModelProvider."""
110- provider = Starcoder2ModelProvider3B ()
111-
112- assert isinstance (provider , Starcoder2ModelProvider )
113- assert hasattr (provider , "provide" )
114- assert callable (provider .provide )
115-
11693
11794class TestStarcoder2ModelProvider7B :
11895 """Test cases for Starcoder2ModelProvider7B class."""
@@ -140,21 +117,13 @@ def test_starcoder2_7b_defaults(self):
140117 assert provider .hidden_dropout == 0.0
141118 assert provider .attention_dropout == 0.0
142119 assert provider .share_embeddings_and_output_weights is False
143- assert provider .kv_channels is None
120+ assert provider .kv_channels is 128
144121 assert provider .window_size is None
145122 assert provider .attention_softmax_in_fp32 is True
146123 assert provider .bias_activation_fusion is True
147124 assert provider .bias_dropout_fusion is True
148125 assert provider .layernorm_epsilon == 1e-5
149126
150- def test_starcoder2_7b_inheritance (self ):
151- """Test Starcoder2ModelProvider7B inherits from Starcoder2ModelProvider."""
152- provider = Starcoder2ModelProvider7B ()
153-
154- assert isinstance (provider , Starcoder2ModelProvider )
155- assert hasattr (provider , "provide" )
156- assert callable (provider .provide )
157-
158127
159128class TestStarcoder2ModelProvider15B :
160129 """Test cases for Starcoder2ModelProvider15B class."""
@@ -182,21 +151,13 @@ def test_starcoder2_15b_defaults(self):
182151 assert provider .hidden_dropout == 0.0
183152 assert provider .attention_dropout == 0.0
184153 assert provider .share_embeddings_and_output_weights is False
185- assert provider .kv_channels is None
154+ assert provider .kv_channels == 128
186155 assert provider .window_size is None
187156 assert provider .attention_softmax_in_fp32 is True
188157 assert provider .bias_activation_fusion is True
189158 assert provider .bias_dropout_fusion is True
190159 assert provider .layernorm_epsilon == 1e-5
191160
192- def test_starcoder2_15b_inheritance (self ):
193- """Test Starcoder2ModelProvider15B inherits from Starcoder2ModelProvider."""
194- provider = Starcoder2ModelProvider15B ()
195-
196- assert isinstance (provider , Starcoder2ModelProvider )
197- assert hasattr (provider , "provide" )
198- assert callable (provider .provide )
199-
200161
201162class TestStarcoder2ProviderInheritance :
202163 """Test inheritance relationships between Starcoder2 providers."""
0 commit comments