Skip to content

Commit e47a886

Browse files
committed
Fix functional tests
Signed-off-by: Abhishree <[email protected]>
1 parent 598b1d8 commit e47a886

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

tests/functional_tests/recipes/utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def run_pretrain_recipe_test(
4747
4. No crashes occur during the process
4848
4949
Args:
50-
config_func: The recipe's pretrain_config function
50+
config_func: The recipe's pretrain_config function (parameterless API)
5151
recipe_name: Name of the recipe for logging/debugging
5252
tmp_path: Temporary directory for test outputs
5353
tensor_model_parallel_size: Override tensor parallelism (None = use recipe default)
@@ -59,9 +59,16 @@ def run_pretrain_recipe_test(
5959
shared_base_dir = broadcast_path(tmp_path)
6060

6161
try:
62-
config: ConfigContainer = config_func(
63-
dir=str(shared_base_dir), name=f"{recipe_name}_functional_test", mock=True
64-
)
62+
# Pretrain configs use parameterless API - call without arguments
63+
config: ConfigContainer = config_func()
64+
65+
# Set up output directories after instantiation
66+
run_output_dir = shared_base_dir / f"{recipe_name}_functional_test"
67+
checkpoint_dir = run_output_dir / "checkpoints"
68+
tensorboard_dir = run_output_dir / "tb_logs"
69+
config.checkpoint.save = str(checkpoint_dir)
70+
config.checkpoint.load = str(checkpoint_dir)
71+
config.logger.tensorboard_dir = str(tensorboard_dir)
6572
# Keep runs short and consistent across tests
6673
config.train.train_iters = 10
6774
config.train.eval_interval = 5
@@ -132,13 +139,14 @@ def run_pretrain_recipe_perf_test(
132139
3. No crashes occur during the process
133140
134141
Args:
135-
config_func: The recipe's pretrain_config function
142+
config_func: The recipe's pretrain_config function (parameterless API)
136143
recipe_name: Name of the recipe for logging/debugging
137144
config_overrides: Optional mapping of config attribute overrides to apply
138145
"""
139146
initialize_distributed()
140147

141-
config: ConfigContainer = config_func(name=f"{recipe_name}_functional_test", mock=True)
148+
# Pretrain configs use parameterless API - call without arguments
149+
config: ConfigContainer = config_func()
142150
# Keep runs short and consistent across tests
143151
config.train.train_iters = 10
144152
config.train.eval_interval = 5

0 commit comments

Comments
 (0)