
Commit 448af9d

Fix test on non cuda machine

1 parent 9fd81d7
1 file changed: src/tests/test_batching.py (3 additions, 0 deletions)
@@ -92,6 +92,7 @@ def test_packing(
 
 
 @pytest.mark.skip_missing_tokenizer
+@patch("llama_recipes.utils.train_utils.torch.cuda.is_bf16_supported")
 @patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
@@ -119,6 +120,7 @@ def test_distributed_packing(
     tokenizer,
     train,
     cuda_is_available,
+    cuda_is_bf16_supported,
     setup_tokenizer,
     setup_processor,
     llama_version,
@@ -133,6 +135,7 @@ def test_distributed_packing(
     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
     cuda_is_available.return_value = False
+    cuda_is_bf16_supported.return_value = False
 
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
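
For context on the pattern this commit extends: unittest.mock applies stacked @patch decorators bottom-up, so the decorator closest to the function supplies the first mock argument. That is why the new cuda_is_bf16_supported parameter lands immediately after cuda_is_available in the test signature. Below is a minimal, self-contained sketch of the same idea; the pick_dtype helper is hypothetical and only stands in for the kind of bf16 check done in llama_recipes.utils.train_utils. Pinning both CUDA queries to False keeps the test deterministic on machines without a GPU.

# Minimal sketch of the mocking pattern above (not the repo's actual test).
from unittest.mock import patch

import torch


def pick_dtype():
    # Hypothetical stand-in for the bf16 capability check in
    # llama_recipes.utils.train_utils: prefer bf16 only when CUDA
    # reports support for it.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32


@patch("torch.cuda.is_bf16_supported")  # outer decorator: second mock argument
@patch("torch.cuda.is_available")       # inner decorator: first mock argument
def test_pick_dtype_without_cuda(cuda_is_available, cuda_is_bf16_supported):
    # Force the "no CUDA" path so the test behaves identically on
    # machines with and without a GPU.
    cuda_is_available.return_value = False
    cuda_is_bf16_supported.return_value = False
    assert pick_dtype() is torch.float32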
