diff --git a/test/integration/test_loading_deprecated_checkpoint.py b/test/integration/test_loading_deprecated_checkpoint.py
index d8bf995a7b..d60ff85b70 100644
--- a/test/integration/test_loading_deprecated_checkpoint.py
+++ b/test/integration/test_loading_deprecated_checkpoint.py
@@ -14,7 +14,7 @@
 )
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 
-from torchao.utils import is_sm_at_least_89
+from torchao.utils import is_fbcode, is_sm_at_least_89
 
 _MODEL_NAME_AND_VERSIONS = [
     ("torchao-testing/opt-125m-float8dq-row-v1-0.13-dev", 1),
@@ -23,6 +23,10 @@
 
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_89(), "Nedd sm89+")
+@unittest.skipIf(
+    is_fbcode(),
+    "Skipping the test in fbcode for now, not sure how to download from transformers",
+)
 class TestLoadingDeprecatedCheckpoint(TestCase):
     @common_utils.parametrize("model_name_and_version", _MODEL_NAME_AND_VERSIONS)
     def test_load_model_and_run(self, model_name_and_version):