
Commit fac71cd

Fix src/tests/test_train_utils.py
1 parent 1090ccf commit fac71cd

File tree

1 file changed: +10 −1 lines changed

src/tests/test_train_utils.py

Lines changed: 10 additions & 1 deletion
@@ -27,7 +27,12 @@ def temp_output_dir():
 @patch("llama_recipes.utils.train_utils.nullcontext")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
-def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
+def test_gradient_accumulation(
+    autocast,
+    scaler,
+    nullcontext,
+    mem_trace,
+    mocker):
 
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
@@ -47,6 +52,9 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.max_train_step = 0
     train_config.max_eval_step = 0
     train_config.save_metrics = False
+    train_config.flop_counter_start = 0
+    train_config.use_profiler = False
+    train_config.flop_counter = True
 
     train(
         model,
@@ -103,6 +111,7 @@ def test_save_to_json(temp_output_dir, mocker):
     train_config.max_train_step = 0
     train_config.max_eval_step = 0
     train_config.output_dir = temp_output_dir
+    train_config.flop_counter_start = 0
     train_config.use_profiler = False
 
     results = train(
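A note on the reflowed signature in the first hunk: stacked @patch decorators inject their mocks bottom-up, so the parameter order (autocast first, from the innermost patch, then scaler, nullcontext, and mem_trace, the latter presumably from a @patch above the lines shown) must stay the same whether the arguments sit on one line or several; mocker is the pytest-mock fixture and is untouched by the patches. A minimal, self-contained sketch of that convention, using json targets purely for illustration (not the llama_recipes API):

from unittest.mock import patch

# Stacked @patch decorators apply bottom-up: the decorator closest to the
# function supplies the first extra argument. "json.dumps" (innermost) arrives
# as `dumps`, "json.loads" (outermost) as `loads`.
@patch("json.loads")
@patch("json.dumps")
def check_patch_order(dumps, loads):
    import json
    json.dumps({"a": 1})
    json.loads("{}")
    dumps.assert_called_once_with({"a": 1})
    loads.assert_called_once_with("{}")

check_patch_order()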

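The new train_config fields pin down attributes that the train loop presumably reads once FLOP counting and profiling are involved. If the test's train_config is a MagicMock-style stub (an assumption here; the hunks do not show how it is built), an unset numeric field breaks as soon as the code under test compares it to a step counter, which is one plausible reason the fix sets explicit values. A small sketch of that failure mode, using only unittest.mock:

from unittest.mock import MagicMock

# Hypothetical stand-in for the test's train_config; the real fixture may differ.
cfg = MagicMock(name="train_config")

# An unset attribute on a MagicMock is itself a MagicMock. Plain reads work,
# but ordering comparisons against ints raise TypeError, because both
# int.__gt__ and MagicMock.__lt__ return NotImplemented.
try:
    _ = 1 > cfg.flop_counter_start
except TypeError as err:
    print(err)  # '>' not supported between instances of 'int' and 'MagicMock'

# Giving the field a concrete value, as the commit does, makes the check
# well-defined:
cfg.flop_counter_start = 0
assert 1 > cfg.flop_counter_start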