@@ -72,10 +72,22 @@ def temp_model_dir():
7272def segmenter_instance (temp_model_dir ):
7373 """Create a TimeSeriesSegmenter instance with test configuration."""
7474 logger .info ("Creating TimeSeriesSegmenter instance for testing" )
75- with patch .dict (os .environ , {'MODEL_DIR' : temp_model_dir , 'TRAIN_EPOCHS' : '10' , 'SEQUENCE_SIZE' : '10' }):
75+ # Patch environment variables for the entire fixture scope
76+ with patch .dict (os .environ , {
77+ 'MODEL_DIR' : temp_model_dir ,
78+ 'TRAIN_EPOCHS' : '10' ,
79+ 'SEQUENCE_SIZE' : '10' ,
80+ 'HIDDEN_SIZE' : '32'
81+ }):
7682 segmenter = TimeSeriesSegmenter (
7783 label_config = LABEL_CONFIG
7884 )
85+ # Override class attributes with test values
86+ segmenter .MODEL_DIR = temp_model_dir
87+ segmenter .TRAIN_EPOCHS = 10
88+ segmenter .SEQUENCE_SIZE = 10
89+ segmenter .HIDDEN_SIZE = 32
90+
7991 segmenter .setup ()
8092 logger .info ("TimeSeriesSegmenter instance created and set up" )
8193 yield segmenter
@@ -115,10 +127,10 @@ def make_task_no_annotations():
115127 "annotations" : [],
116128 }
117129
118- def fake_preload (self , task , value = None , read_file = True ):
130+ def fake_preload (self , task , path , read_file = True ):
119131 """Mock function to preload CSV data."""
120- logger .debug (f"Mock preload called with value : { value } " )
121- return open (value ).read ()
132+ logger .debug (f"Mock preload called with path : { path } " )
133+ return open (path ).read ()
122134
123135class TestTimeSeriesSegmenter :
124136 """Test suite for TimeSeriesSegmenter with PyTorch LSTM implementation."""
@@ -461,21 +473,25 @@ def test_model_parameters_configuration(self, temp_model_dir):
461473 """Test different model parameter configurations."""
462474 logger .info ("=== Testing model parameters configuration ===" )
463475 configs = [
464- {"SEQUENCE_SIZE" : "5" , "HIDDEN_SIZE" : "16" },
465- {"SEQUENCE_SIZE" : "20" , "HIDDEN_SIZE" : "32" },
466- {"SEQUENCE_SIZE" : "50" , "HIDDEN_SIZE" : "64" },
476+ {"SEQUENCE_SIZE" : 5 , "HIDDEN_SIZE" : 16 },
477+ {"SEQUENCE_SIZE" : 20 , "HIDDEN_SIZE" : 32 },
478+ {"SEQUENCE_SIZE" : 50 , "HIDDEN_SIZE" : 64 },
467479 ]
468480
469481 for i , config in enumerate (configs ):
470482 logger .info (f"Testing configuration { i + 1 } /{ len (configs )} : { config } " )
471- with patch .dict (os .environ , {** config , 'MODEL_DIR' : temp_model_dir }):
472- segmenter = TimeSeriesSegmenter (
473- label_config = LABEL_CONFIG
474- )
475- segmenter .setup ()
476-
477- model = segmenter ._build_model (n_channels = 2 , n_labels = 3 )
478- logger .info (f"Created model with sequence_size={ model .sequence_size } , hidden_size={ model .hidden_size } " )
479- assert model .sequence_size == int (config ["SEQUENCE_SIZE" ])
480- assert model .hidden_size == int (config ["HIDDEN_SIZE" ])
483+ segmenter = TimeSeriesSegmenter (
484+ label_config = LABEL_CONFIG
485+ )
486+ # Override instance attributes with test values
487+ segmenter .MODEL_DIR = temp_model_dir
488+ segmenter .SEQUENCE_SIZE = config ["SEQUENCE_SIZE" ]
489+ segmenter .HIDDEN_SIZE = config ["HIDDEN_SIZE" ]
490+
491+ segmenter .setup ()
492+
493+ model = segmenter ._build_model (n_channels = 2 , n_labels = 3 )
494+ logger .info (f"Created model with sequence_size={ model .sequence_size } , hidden_size={ model .hidden_size } " )
495+ assert model .sequence_size == config ["SEQUENCE_SIZE" ]
496+ assert model .hidden_size == config ["HIDDEN_SIZE" ]
481497 logger .info ("✓ Model parameters configuration test passed" )
0 commit comments