Skip to content

Commit b437bfa

Browse files
committed
Fix tests
1 parent 3620bfe commit b437bfa

File tree

1 file changed

+33
-17
lines changed

1 file changed

+33
-17
lines changed

label_studio_ml/examples/timeseries_segmenter/tests/test_segmenter.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,22 @@ def temp_model_dir():
7272
def segmenter_instance(temp_model_dir):
7373
"""Create a TimeSeriesSegmenter instance with test configuration."""
7474
logger.info("Creating TimeSeriesSegmenter instance for testing")
75-
with patch.dict(os.environ, {'MODEL_DIR': temp_model_dir, 'TRAIN_EPOCHS': '10', 'SEQUENCE_SIZE': '10'}):
75+
# Patch environment variables for the entire fixture scope
76+
with patch.dict(os.environ, {
77+
'MODEL_DIR': temp_model_dir,
78+
'TRAIN_EPOCHS': '10',
79+
'SEQUENCE_SIZE': '10',
80+
'HIDDEN_SIZE': '32'
81+
}):
7682
segmenter = TimeSeriesSegmenter(
7783
label_config=LABEL_CONFIG
7884
)
85+
# Override class attributes with test values
86+
segmenter.MODEL_DIR = temp_model_dir
87+
segmenter.TRAIN_EPOCHS = 10
88+
segmenter.SEQUENCE_SIZE = 10
89+
segmenter.HIDDEN_SIZE = 32
90+
7991
segmenter.setup()
8092
logger.info("TimeSeriesSegmenter instance created and set up")
8193
yield segmenter
@@ -115,10 +127,10 @@ def make_task_no_annotations():
115127
"annotations": [],
116128
}
117129

118-
def fake_preload(self, task, value=None, read_file=True):
130+
def fake_preload(self, task, path, read_file=True):
119131
"""Mock function to preload CSV data."""
120-
logger.debug(f"Mock preload called with value: {value}")
121-
return open(value).read()
132+
logger.debug(f"Mock preload called with path: {path}")
133+
return open(path).read()
122134

123135
class TestTimeSeriesSegmenter:
124136
"""Test suite for TimeSeriesSegmenter with PyTorch LSTM implementation."""
@@ -461,21 +473,25 @@ def test_model_parameters_configuration(self, temp_model_dir):
461473
"""Test different model parameter configurations."""
462474
logger.info("=== Testing model parameters configuration ===")
463475
configs = [
464-
{"SEQUENCE_SIZE": "5", "HIDDEN_SIZE": "16"},
465-
{"SEQUENCE_SIZE": "20", "HIDDEN_SIZE": "32"},
466-
{"SEQUENCE_SIZE": "50", "HIDDEN_SIZE": "64"},
476+
{"SEQUENCE_SIZE": 5, "HIDDEN_SIZE": 16},
477+
{"SEQUENCE_SIZE": 20, "HIDDEN_SIZE": 32},
478+
{"SEQUENCE_SIZE": 50, "HIDDEN_SIZE": 64},
467479
]
468480

469481
for i, config in enumerate(configs):
470482
logger.info(f"Testing configuration {i+1}/{len(configs)}: {config}")
471-
with patch.dict(os.environ, {**config, 'MODEL_DIR': temp_model_dir}):
472-
segmenter = TimeSeriesSegmenter(
473-
label_config=LABEL_CONFIG
474-
)
475-
segmenter.setup()
476-
477-
model = segmenter._build_model(n_channels=2, n_labels=3)
478-
logger.info(f"Created model with sequence_size={model.sequence_size}, hidden_size={model.hidden_size}")
479-
assert model.sequence_size == int(config["SEQUENCE_SIZE"])
480-
assert model.hidden_size == int(config["HIDDEN_SIZE"])
483+
segmenter = TimeSeriesSegmenter(
484+
label_config=LABEL_CONFIG
485+
)
486+
# Override instance attributes with test values
487+
segmenter.MODEL_DIR = temp_model_dir
488+
segmenter.SEQUENCE_SIZE = config["SEQUENCE_SIZE"]
489+
segmenter.HIDDEN_SIZE = config["HIDDEN_SIZE"]
490+
491+
segmenter.setup()
492+
493+
model = segmenter._build_model(n_channels=2, n_labels=3)
494+
logger.info(f"Created model with sequence_size={model.sequence_size}, hidden_size={model.hidden_size}")
495+
assert model.sequence_size == config["SEQUENCE_SIZE"]
496+
assert model.hidden_size == config["HIDDEN_SIZE"]
481497
logger.info("✓ Model parameters configuration test passed")

0 commit comments

Comments
 (0)