Skip to content

Commit 4ed349c

Browse files
committed
update tests
Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent 238fdc3 commit 4ed349c

File tree

3 files changed

+31
-96
lines changed

3 files changed

+31
-96
lines changed

tests/unit_tests/conftest.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919

2020
import pytest
2121

22-
23-
# from tests.unit_tests.download_unit_tests_dataset import get_oldest_release_and_assets
22+
from tests.unit_tests.download_unit_tests_dataset import get_oldest_release_and_assets
2423

2524

2625
logging.basicConfig(level=logging.INFO)
@@ -42,34 +41,34 @@ def cleanup_local_folder():
4241
rmtree("./nemo_experiments", ignore_errors=True)
4342

4443

45-
# @pytest.fixture(scope="session", autouse=True)
46-
# def ensure_test_data(tmp_path_factory):
47-
# """Ensure test data is available in a temporary directory by downloading if necessary."""
48-
# data_path = tmp_path_factory.mktemp("test_data")
44+
@pytest.fixture(scope="session", autouse=True)
45+
def ensure_test_data(tmp_path_factory):
46+
"""Ensure test data is available in a temporary directory by downloading if necessary."""
47+
data_path = tmp_path_factory.mktemp("test_data")
4948

50-
# # Check if data directory exists and has content
51-
# if not any(data_path.iterdir()):
52-
# logger.info(f"Test data not found at {data_path}. Downloading...")
49+
# Check if data directory exists and has content
50+
if not any(data_path.iterdir()):
51+
logger.info(f"Test data not found at {data_path}. Downloading...")
5352

54-
# try:
55-
# # Download assets to data_path
56-
# get_oldest_release_and_assets(assets_dir=str(data_path))
53+
try:
54+
# Download assets to data_path
55+
get_oldest_release_and_assets(assets_dir=str(data_path))
5756

58-
# logger.info("Test data downloaded successfully.")
57+
logger.info("Test data downloaded successfully.")
5958

60-
# except ImportError as e:
61-
# logger.info(f"Failed to import download function: {e}")
62-
# except ValueError as e:
63-
# logger.error(e)
64-
# pytest.exit(f"Failed to download test data: {e}", returncode=1)
65-
# # Don't fail the tests, just warn
66-
# except Exception as e:
67-
# logger.info(f"Failed to download test data: {e}")
68-
# # Don't fail the tests, just warn
69-
# else:
70-
# logger.info(f"Test data already available at {data_path}")
59+
except ImportError as e:
60+
logger.info(f"Failed to import download function: {e}")
61+
except ValueError as e:
62+
logger.error(e)
63+
pytest.exit(f"Failed to download test data: {e}", returncode=1)
64+
# Don't fail the tests, just warn
65+
except Exception as e:
66+
logger.info(f"Failed to download test data: {e}")
67+
# Don't fail the tests, just warn
68+
else:
69+
logger.info(f"Test data already available at {data_path}")
7170

72-
# yield data_path
71+
yield data_path
7372

7473

7574
@pytest.fixture(autouse=True)

tests/unit_tests/models/starcoder/test_starcoder2_provider.py

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_starcoder2_model_provider_defaults(self):
3838
assert provider.hidden_size == 768
3939
assert provider.num_attention_heads == 12
4040

41-
# Check Starcoder2-specific defaults
41+
# Check Starcoder2-specific defaults + transformer config post init
4242
assert provider.normalization == "LayerNorm"
4343
assert provider.activation_func == F.gelu
4444
assert provider.add_bias_linear is True
@@ -49,28 +49,14 @@ def test_starcoder2_model_provider_defaults(self):
4949
assert provider.attention_dropout == 0.0
5050
assert provider.init_method_std == 0.01
5151
assert provider.share_embeddings_and_output_weights is False
52-
assert provider.kv_channels is None
53-
assert provider.num_query_groups is None
52+
assert provider.kv_channels == 64
53+
assert provider.num_query_groups == 12
5454
assert provider.window_size is None
5555
assert provider.attention_softmax_in_fp32 is True
5656
assert provider.bias_activation_fusion is True
5757
assert provider.bias_dropout_fusion is True
5858
assert provider.layernorm_epsilon == 1e-5
5959

60-
def test_starcoder2_model_provider_inheritance(self):
61-
"""Test Starcoder2ModelProvider inherits from GPTModelProvider."""
62-
from megatron.bridge.models.gpt_provider import GPTModelProvider
63-
64-
provider = Starcoder2ModelProvider(
65-
num_layers=12,
66-
hidden_size=768,
67-
num_attention_heads=12,
68-
)
69-
70-
assert isinstance(provider, GPTModelProvider)
71-
assert hasattr(provider, "provide")
72-
assert callable(provider.provide)
73-
7460

7561
class TestStarcoder2ModelProvider3B:
7662
"""Test cases for Starcoder2ModelProvider3B class."""
@@ -98,21 +84,12 @@ def test_starcoder2_3b_defaults(self):
9884
assert provider.hidden_dropout == 0.0
9985
assert provider.attention_dropout == 0.0
10086
assert provider.share_embeddings_and_output_weights is False
101-
assert provider.kv_channels is None
10287
assert provider.window_size is None
10388
assert provider.attention_softmax_in_fp32 is True
10489
assert provider.bias_activation_fusion is True
10590
assert provider.bias_dropout_fusion is True
10691
assert provider.layernorm_epsilon == 1e-5
10792

108-
def test_starcoder2_3b_inheritance(self):
109-
"""Test Starcoder2ModelProvider3B inherits from Starcoder2ModelProvider."""
110-
provider = Starcoder2ModelProvider3B()
111-
112-
assert isinstance(provider, Starcoder2ModelProvider)
113-
assert hasattr(provider, "provide")
114-
assert callable(provider.provide)
115-
11693

11794
class TestStarcoder2ModelProvider7B:
11895
"""Test cases for Starcoder2ModelProvider7B class."""
@@ -140,21 +117,13 @@ def test_starcoder2_7b_defaults(self):
140117
assert provider.hidden_dropout == 0.0
141118
assert provider.attention_dropout == 0.0
142119
assert provider.share_embeddings_and_output_weights is False
143-
assert provider.kv_channels is None
120+
assert provider.kv_channels is 128
144121
assert provider.window_size is None
145122
assert provider.attention_softmax_in_fp32 is True
146123
assert provider.bias_activation_fusion is True
147124
assert provider.bias_dropout_fusion is True
148125
assert provider.layernorm_epsilon == 1e-5
149126

150-
def test_starcoder2_7b_inheritance(self):
151-
"""Test Starcoder2ModelProvider7B inherits from Starcoder2ModelProvider."""
152-
provider = Starcoder2ModelProvider7B()
153-
154-
assert isinstance(provider, Starcoder2ModelProvider)
155-
assert hasattr(provider, "provide")
156-
assert callable(provider.provide)
157-
158127

159128
class TestStarcoder2ModelProvider15B:
160129
"""Test cases for Starcoder2ModelProvider15B class."""
@@ -182,21 +151,13 @@ def test_starcoder2_15b_defaults(self):
182151
assert provider.hidden_dropout == 0.0
183152
assert provider.attention_dropout == 0.0
184153
assert provider.share_embeddings_and_output_weights is False
185-
assert provider.kv_channels is None
154+
assert provider.kv_channels == 128
186155
assert provider.window_size is None
187156
assert provider.attention_softmax_in_fp32 is True
188157
assert provider.bias_activation_fusion is True
189158
assert provider.bias_dropout_fusion is True
190159
assert provider.layernorm_epsilon == 1e-5
191160

192-
def test_starcoder2_15b_inheritance(self):
193-
"""Test Starcoder2ModelProvider15B inherits from Starcoder2ModelProvider."""
194-
provider = Starcoder2ModelProvider15B()
195-
196-
assert isinstance(provider, Starcoder2ModelProvider)
197-
assert hasattr(provider, "provide")
198-
assert callable(provider.provide)
199-
200161

201162
class TestStarcoder2ProviderInheritance:
202163
"""Test inheritance relationships between Starcoder2 providers."""

tests/unit_tests/models/starcoder/test_starcoder_provider.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -47,26 +47,12 @@ def test_starcoder_model_provider_defaults(self):
4747
assert provider.init_method_std == 0.01
4848
assert provider.layernorm_epsilon == 1e-5
4949
assert provider.share_embeddings_and_output_weights is False
50-
assert provider.kv_channels is None
50+
assert provider.kv_channels == 64
5151
assert provider.num_query_groups == 1
5252
assert provider.attention_softmax_in_fp32 is True
5353
assert provider.bias_activation_fusion is True
5454
assert provider.bias_dropout_fusion is True
5555

56-
def test_starcoder_model_provider_inheritance(self):
57-
"""Test StarcoderModelProvider inherits from GPTModelProvider."""
58-
from megatron.bridge.models.gpt_provider import GPTModelProvider
59-
60-
provider = StarcoderModelProvider(
61-
num_layers=12,
62-
hidden_size=768,
63-
num_attention_heads=12,
64-
)
65-
66-
assert isinstance(provider, GPTModelProvider)
67-
assert hasattr(provider, "provide")
68-
assert callable(provider.provide)
69-
7056

7157
class TestStarcoderConfig15B:
7258
"""Test cases for StarcoderConfig15B class."""
@@ -92,20 +78,12 @@ def test_starcoder_config_15b_defaults(self):
9278
assert provider.attention_dropout == 0.2
9379
assert provider.layernorm_epsilon == 1e-5
9480
assert provider.share_embeddings_and_output_weights is False
95-
assert provider.kv_channels is None
81+
assert provider.kv_channels == 128
9682
assert provider.num_query_groups == 1
9783
assert provider.attention_softmax_in_fp32 is True
9884
assert provider.bias_activation_fusion is True
9985
assert provider.bias_dropout_fusion is True
10086

101-
def test_starcoder_config_15b_inheritance(self):
102-
"""Test StarcoderConfig15B inherits from StarcoderModelProvider."""
103-
provider = StarcoderConfig15B()
104-
105-
assert isinstance(provider, StarcoderModelProvider)
106-
assert hasattr(provider, "provide")
107-
assert callable(provider.provide)
108-
10987

11088
class TestStarcoderProviderInheritance:
11189
"""Test inheritance relationships between Starcoder providers."""
@@ -120,9 +98,6 @@ def test_starcoder_models_inherit_from_gpt(self):
12098

12199
def test_provide_method_inherited(self):
122100
"""Test that provide method works correctly in inherited classes."""
123-
# Test with StarcoderConfig15B
124101
provider = StarcoderConfig15B()
125-
126-
# The provide method should be inherited from GPTModelProvider
127102
assert hasattr(provider, "provide")
128103
assert callable(provider.provide)

0 commit comments

Comments
 (0)