diff --git a/.github/scripts/ci_test_xpu.sh b/.github/scripts/ci_test_xpu.sh new file mode 100644 index 0000000000..ccff1b848f --- /dev/null +++ b/.github/scripts/ci_test_xpu.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +python3 -m pip install torch torchvision torchaudio pytorch-triton-xpu --index-url https://download.pytorch.org/whl/nightly/xpu --force-reinstall --no-cache-dir +python3 setup.py install + +pip install pytest expecttest parameterized accelerate hf_transfer 'modelscope!=1.15.0' + +cd test/quantization +pytest -v -s *.py diff --git a/.github/workflows/pr-test-xpu.yml b/.github/workflows/pr-test-xpu.yml new file mode 100644 index 0000000000..79621a06d1 --- /dev/null +++ b/.github/workflows/pr-test-xpu.yml @@ -0,0 +1,156 @@ +# TODO: this looks sort of similar to _linux-test, but there are like a dozen +# places where you would have to insert an if statement. Probably it's better to +# just use a different workflow altogether + +name: xpu-test + +on: + push: + branches: + - main + - 'gh/**' + pull_request: + branches: + - main + - 'gh/**' + +concurrency: + group: xpu_ci_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +jobs: + test: + # Don't run on forked repos or empty test matrix + # if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]' + timeout-minutes: 60 + runs-on: ao-pvc + env: + DOCKER_IMAGE: ghcr.io/pytorch/ci-image:pytorch-linux-jammy-xpu-2025.1-py3-b388c12018df5d6ce2f94b7fb337fa3729978ab3 + TEST_COMMAND: .github/scripts/ci_test_xpu.sh + PYTORCH_RETRY_TEST_CASES: 1 + PYTORCH_OVERRIDE_FLAKY_SIGNAL: 1 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla + steps: + # [see note: pytorch repo ref] + - name: Checkout Torchao + uses: actions/checkout@v4 + + - name: Clean all stopped docker containers + if: always() + shell: bash + run: | + # Prune all stopped containers. + # If other runner is pruning on this node, will skip. + nprune=$(ps -ef | grep -c "docker container prune") + if [[ $nprune -eq 1 ]]; then + docker container prune -f + fi + + - name: Runner health check GPU count + if: always() + shell: bash + run: | + ngpu=$(timeout 30 clinfo -l | grep -c -E 'Device' || true) + msg="Please file an issue on pytorch/ao reporting the faulty runner. 
Include a link to the runner logs so the runner can be identified" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + echo "$msg" + exit 1 + fi + + - name: Use following to pull public copy of the image + id: print-ghcr-mirror + shell: bash + run: | + echo "docker pull ${DOCKER_IMAGE}" + docker pull ${DOCKER_IMAGE} + + - name: Test + id: test + env: + BUILD_ENVIRONMENT: ${{ inputs.build-environment }} + PR_NUMBER: ${{ github.event.pull_request.number }} + GITHUB_REPOSITORY: ${{ github.repository }} + GITHUB_WORKFLOW: ${{ github.workflow }} + GITHUB_JOB: ${{ github.job }} + GITHUB_RUN_ID: ${{ github.run_id }} + GITHUB_RUN_NUMBER: ${{ github.run_number }} + GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }} + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + timeout-minutes: 60 + run: | + set -x + + # detached container should get cleaned up by teardown_ec2_linux + # Used for GPU_FLAG since that doesn't play nice + # shellcheck disable=SC2086,SC2090 + container_name=$(docker run \ + ${GPU_FLAG:-} \ + -e BUILD_ENVIRONMENT \ + -e PR_NUMBER \ + -e GITHUB_ACTIONS \ + -e GITHUB_REPOSITORY \ + -e GITHUB_WORKFLOW \ + -e GITHUB_JOB \ + -e GITHUB_RUN_ID \ + -e GITHUB_RUN_NUMBER \ + -e GITHUB_RUN_ATTEMPT \ + -e JOB_ID \ + -e BRANCH \ + -e SHA1 \ + --user $(id -u):$(id -g) \ + --ulimit stack=10485760:83886080 \ + --ulimit core=0 \ + --security-opt seccomp=unconfined \ + --cap-add=SYS_PTRACE \ + --shm-size="8g" \ + --tty \ + --detach \ + --name="${container_name}" \ + --user jenkins \ + --privileged \ + -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ + -w /var/lib/jenkins/workspace \ + "${DOCKER_IMAGE}" + ) + # save container name for later step + echo "CONTAINER_NAME=${container_name}" >> "$GITHUB_ENV" + # jenkins user does not have write permission to mounted workspace; work-around by copying within container to jenkins home + docker exec -t "${container_name}" sh -c "bash ${TEST_COMMAND}" + + - name: Change permissions + if: ${{ always() && steps.test.conclusion }} + run: | + docker exec -t "${{ env.CONTAINER_NAME }}" sh -c "sudo chown -R 1001:1001 test" + + - name: Collect backtraces from coredumps (if any) + if: always() + run: | + # shellcheck disable=SC2156 + find . -iname "core.[1-9]*" -exec docker exec "${CONTAINER_NAME}" sh -c "gdb python {} -ex 'bt' -ex 'q'" \; + + - name: Stop container before exit + if: always() + run: | + # Workaround for multiple runners on same IDC node + docker stop "${{ env.CONTAINER_NAME }}" + + - name: Store Core dumps on GitHub + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: failure() + with: + name: coredumps-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }} + retention-days: 14 + if-no-files-found: ignore + path: ./**/core.[1-9]* + + - name: Teardown XPU + if: always() + shell: bash + run: | + # Prune all stopped containers. + # If other runner is pruning on this node, will skip. 
+ nprune=$(ps -ef | grep -c "docker container prune") + if [[ $nprune -eq 1 ]]; then + docker container prune -f + fi diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index bd5ed0c3b5..d86fe1b3a4 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -48,15 +48,18 @@ is_ROCM, is_sm_at_least_89, is_sm_at_least_90, + auto_detect_device, ) is_cusparselt_available = ( hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available() ) +_DEVICE = auto_detect_device() + def get_quantization_functions( - do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False + do_sparse: bool, do_int4: bool, device: str =_DEVICE, int4_zp_int: bool = False ): base_functions = [ int8_weight_only(), @@ -114,9 +117,9 @@ class TestAffineQuantized(TestCase): ["xpu"] if torch.xpu.is_available() else [] ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tensor_core_layout_transpose(self): - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE) t = linear.weight shape = t.shape apply_int4_weight_only_quant = int4_weight_only(group_size=32) @@ -182,7 +185,7 @@ def _apply(module, config_or_subclass_inserter): ql = _apply(linear, apply_quant) ql.to(device) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_register_new_dispatch(self): from torchao.dtypes import AffineQuantizedTensor from torchao.dtypes.affine_quantized_tensor_ops import ( @@ -219,10 +222,10 @@ def apply_uint6_weight_only_quant(linear): ) return linear - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE) apply_uint6_weight_only_quant(linear) - example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") + example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE) with self.assertRaisesRegex( AssertionError, "dispatching to my impl for uint6 weight only quant" ): @@ -245,13 +248,13 @@ def test_print_quantized_module(self): ql = apply_quant(linear) assert "AffineQuantizedTensor" in str(ql) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @common_utils.parametrize( - "apply_quant", get_quantization_functions(False, True, "cuda", False) + "apply_quant", get_quantization_functions(False, True, _DEVICE, False) ) def test_test_copy__apply(self, apply_quant): - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE) + linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE) if isinstance(apply_quant, AOBaseConfig): quantize_(linear, apply_quant) @@ -262,20 +265,20 @@ def test_test_copy__apply(self, apply_quant): ql = apply_quant(linear) ql2 = apply_quant(linear2) - example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") + example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE) output = ql(example_input) ql2.weight.copy_(ql.weight) ql2.bias = ql.bias output2 = ql2(example_input) self.assertEqual(output, output2) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @common_utils.parametrize( - "apply_quant", get_quantization_functions(False, True, "cuda", False) + "apply_quant", 
get_quantization_functions(False, True, _DEVICE, False) ) def test_copy__mismatch_metadata(self, apply_quant): - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda") + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE) + linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device=_DEVICE) if isinstance(apply_quant, AOBaseConfig): quantize_(linear, apply_quant) @@ -349,9 +352,8 @@ def test_alias(self, device, dtype): quantize_(dummy, Int8DynamicActivationInt8WeightConfig()) _ = dummy.weight[...] - @common_utils.parametrize("device", ["cuda"]) + @common_utils.parametrize("device", [_DEVICE]) @common_utils.parametrize("dtype", [torch.bfloat16]) - @skip_if_no_cuda() @skip_if_rocm("ROCm enablement in progress") def test_slice_int4wo(self, device, dtype): # in_feature not divisible by 1024 @@ -363,9 +365,7 @@ def test_slice_int4wo(self, device, dtype): _ = dummy.weight.narrow(0, 0, 64) _ = dummy.weight.narrow(1, 0, 128) - @common_utils.parametrize("device", ["cuda"]) @common_utils.parametrize("dtype", [torch.float16, torch.bfloat16]) - @skip_if_no_cuda() @skip_if_no_gemlite() def test_slice_gemlite(self, device, dtype): # in_feature not divisible by 1024 @@ -446,7 +446,7 @@ def dequant(input_layer, in_features, orig_shape): ) self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0) - @common_utils.parametrize("device", ["cuda"]) + @common_utils.parametrize("device", [_DEVICE]) @common_utils.parametrize("dtype", [torch.bfloat16]) def test_matmul(self, device, dtype): x = torch.randn(53, 2048) @@ -463,14 +463,13 @@ def test_matmul(self, device, dtype): # make sure it runs torch.matmul(x, w.t()) - @common_utils.parametrize("device", ["cuda"]) + @common_utils.parametrize("device", [_DEVICE]) @common_utils.parametrize("dtype", [torch.bfloat16]) - @skip_if_no_cuda() @skip_if_rocm("ROCm enablement in progress") def test_slice_and_copy_int4wo(self, device, dtype): - l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16) l.weight = torch.nn.Parameter( - torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + torch.zeros(1024, 1024, dtype=torch.bfloat16, device=_DEVICE) ) quantize_(l, Int4WeightOnlyConfig()) param = l.weight @@ -487,7 +486,7 @@ def test_slice_and_copy_int4wo(self, device, dtype): assert param.data.dequantize()[0][0] == 0 # dummy_l has random input (shouldn't be 0) - dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + dummy_l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16) quantize_(dummy_l, Int4WeightOnlyConfig()) quantized = dummy_l.weight quantized = quantized.narrow(0, 0, 512) @@ -497,9 +496,8 @@ def test_slice_and_copy_int4wo(self, device, dtype): # making sure param.data is updated assert param.data.dequantize()[0][0] != 0 - @common_utils.parametrize("device", ["cuda"]) + @common_utils.parametrize("device", [_DEVICE]) @common_utils.parametrize("dtype", [torch.bfloat16]) - @skip_if_no_cuda() @skip_if_rocm("ROCm enablement in progress") def test_mm_int4wo(self, device, dtype): weight = torch.randn(512, 1024).to(device).to(dtype) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index ee1849a289..2e7b8adac3 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -56,6 +56,10 @@ random.seed(0) torch.manual_seed(0) +from torchao.utils import 
auto_detect_device + +_DEVICE = auto_detect_device() + class ToyLinearModel(torch.nn.Module): def __init__(self, in_features, out_features): @@ -70,9 +74,8 @@ def forward(self, x): class TestAffineQuantizedFloat8Compile(InductorTestCase): - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) @@ -106,7 +109,7 @@ def test_fp8_linear_variants( with error_context: M, N, K = sizes - input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + input_tensor = torch.randn(*M, K, dtype=dtype, device=_DEVICE) # Get a "reasonable" scale for the input tensor even though # we use the same scale for multiple activations scale, _ = choose_qparams_affine( @@ -129,7 +132,7 @@ def test_fp8_linear_variants( } # Create a linear layer with bfloat16 dtype - model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + model = ToyLinearModel(K, N).eval().to(dtype).to(_DEVICE) quantized_model = copy.deepcopy(model) factory = mode_map[mode]() @@ -147,14 +150,14 @@ def test_fp8_linear_variants( ) @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) def test_invalid_granularity(self): with pytest.raises(ValueError, match="Invalid granularity specification"): float8_dynamic_activation_float8_weight(granularity="invalid") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) def test_mismatched_granularity(self): with pytest.raises( @@ -164,7 +167,7 @@ def test_mismatched_granularity(self): float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow())) @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) def test_unsupported_granularity(self): class UnsupportedGranularity: @@ -175,28 +178,26 @@ class UnsupportedGranularity: granularity=(UnsupportedGranularity(), UnsupportedGranularity()) ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) def test_per_row_with_float32(self): with pytest.raises( AssertionError, match="PerRow quantization only works for bfloat16 precision", ): - model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda") + model = ToyLinearModel(64, 64).eval().to(torch.float32).to(_DEVICE) quantize_( model, float8_dynamic_activation_float8_weight(granularity=PerRow()) ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) def test_serialization(self, mode: str): # Create and quantize the model - model = ToyLinearModel(16, 32).to(device="cuda") + model = 
ToyLinearModel(16, 32).to(device=_DEVICE) mode_map = { "dynamic": partial( @@ -205,7 +206,7 @@ def test_serialization(self, mode: str): "weight-only": float8_weight_only, "static": partial( float8_static_activation_float8_weight, - scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"), + scale=torch.tensor(1.0, dtype=torch.float32, device=_DEVICE), granularity=PerTensor(), ), } @@ -261,9 +262,8 @@ def test_serialization(self, mode: str): original_layer.weight.scale, new_layer.weight.scale ), f"Scales do not match for {layer_name}" - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) def test_fp8_weight_dimension_warning(self): # Create model with incompatible dimensions (not multiples of 16) @@ -296,9 +296,8 @@ def test_fp8_weight_dimension_warning(self): f"Expected warning message containing: {expected}", ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) @common_utils.parametrize( "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)] @@ -310,12 +309,12 @@ def test_fp8_weight_dimension_warning(self): def test_mm_float8dq_per_row( self, in_features, out_features, leading_shape, bias: bool ): - device = "cuda" + dtype = torch.bfloat16 input_shape = leading_shape + (in_features,) ref_linear = ( - torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype) + torch.nn.Linear(in_features, out_features, bias=bias).to(_DEVICE).to(dtype) ) test_linear = copy.deepcopy(ref_linear) quantize_( @@ -338,7 +337,7 @@ def test_mm_float8dq_per_row( self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features)) - input_tensor = torch.randn(*input_shape, device=device, dtype=dtype) + input_tensor = torch.randn(*input_shape, device=_DEVICE, dtype=dtype) with torch.no_grad(): ref_output = ref_linear(input_tensor) @@ -350,16 +349,15 @@ def test_mm_float8dq_per_row( error = compute_error(ref_output, quant_output) assert error > 20, f"Quantization error is too high got a SQNR of {error}" - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16]) def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype): block_size = () - device = "cuda" - input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32) + + input_tensor = torch.randn(8, 64, device=_DEVICE, dtype=torch.float32) # testing upper bounds input_tensor[0][0] = 2000 @@ -379,7 +377,7 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype): # tesing lower bounds settings # making sure that abs is on the scale of 1e-20, so hp_value_lb can take effect - input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32) * 1e-20 + input_tensor = torch.randn(8, 64, device=_DEVICE, dtype=torch.float32) * 1e-20 scale_ref = _choose_scale_float8( input_tensor, float8_dtype=float8_dtype, block_size=block_size ) @@ 
-393,9 +391,8 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype): # since scale = abs_max / quant_max, larger abs_max means scale is larger self.assertTrue(scale_ref < scale_with_lb) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16]) @@ -403,8 +400,8 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype): def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size): """Test _dequantize_affine_float8 with various configurations""" - device = "cuda" - input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32) + + input_tensor = torch.randn(8, 64, device=_DEVICE, dtype=torch.float32) # Choose quantization parameters scale = _choose_scale_float8( @@ -426,15 +423,14 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size): error = torch.abs(input_tensor.to(output_dtype) - dequantized).mean() self.assertLess(error, 0.1, "Quantization error too high") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) def test_dequantize_affine_float8_scale_broadcasting(self): """Test that scale broadcasting works correctly for block-wise quantization""" - device = "cuda" + # Create input tensor with known block structure - input_tensor = torch.randn(4, 32, device=device, dtype=torch.float32) + input_tensor = torch.randn(4, 32, device=_DEVICE, dtype=torch.float32) block_size = (2, 16) # 2x2 blocks in first dim, 2x16 blocks in second dim # Choose quantization parameters @@ -458,18 +454,17 @@ def test_dequantize_affine_float8_scale_broadcasting(self): # Verify shapes match self.assertEqual(dequantized.shape, input_tensor.shape) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) def test_float8_tensor_slicing_basic(self, granularity): """Test basic slicing operations on Float8 tensors""" - device = "cuda" + dtype = torch.bfloat16 # Create and quantize a model - model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) + model = torch.nn.Linear(64, 32, bias=False).to(_DEVICE).to(dtype) quantize_( model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity) ) @@ -493,17 +488,16 @@ def test_float8_tensor_slicing_basic(self, granularity): self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl)) self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl)) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) def test_float8_tensor_slicing_per_tensor(self): """Test slicing with per-tensor quantization (scale should not change)""" - device = "cuda" + dtype = 
torch.bfloat16 # Create and quantize with per-tensor granularity - model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) + model = torch.nn.Linear(64, 32, bias=False).to(_DEVICE).to(dtype) quantize_( model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()) ) @@ -520,9 +514,8 @@ def test_float8_tensor_slicing_per_tensor(self): self.assertTrue(torch.equal(original_scale, sliced_impl.scale)) self.assertEqual(sliced_impl.scale.numel(), 1) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) @unittest.skipIf( not is_sm_at_least_90(), @@ -530,11 +523,11 @@ def test_float8_tensor_slicing_per_tensor(self): ) def test_float8_tensor_slicing_per_row(self): """Test slicing with per-row quantization (scale should be sliced appropriately)""" - device = "cuda" + dtype = torch.bfloat16 # Create and quantize with per-row granularity - model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) + model = torch.nn.Linear(64, 32, bias=False).to(_DEVICE).to(dtype) quantize_( model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) ) @@ -562,17 +555,16 @@ def test_float8_tensor_slicing_per_row(self): self.assertEqual(sliced_cols_impl.scale.shape, (32, 1)) self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale)) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) def test_float8_tensor_slicing_edge_cases(self): """Test edge cases in slicing""" - device = "cuda" + dtype = torch.bfloat16 # Create and quantize a model - model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) + model = torch.nn.Linear(64, 32, bias=False).to(_DEVICE).to(dtype) quantize_( model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()) ) @@ -591,9 +583,8 @@ def test_float8_tensor_slicing_edge_cases(self): large_slice = original_weight[:100] # More than available rows self.assertEqual(large_slice.shape, (32, 64)) # Should clamp to available - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + _DEVICE == "cuda" and not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) @unittest.skipIf( @@ -602,12 +593,12 @@ def test_float8_tensor_slicing_edge_cases(self): ) def test_float8_tensor_slicing_functional_correctness(self, granularity): """Test that sliced tensors produce correct results in computations""" - device = "cuda" + dtype = torch.bfloat16 # Create reference and quantized models with dimensions that are multiples of 16 ref_model = ( - torch.nn.Linear(64, 48, bias=False).to(device).to(dtype) + torch.nn.Linear(64, 48, bias=False).to(_DEVICE).to(dtype) ) # 48 is divisible by 16 quant_model = copy.deepcopy(ref_model) quantize_( @@ -616,7 +607,7 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity): ) # Create input with batch size that works well with slicing - input_tensor = torch.randn(8, 64, device=device, dtype=dtype) + input_tensor = torch.randn(8, 64, device=_DEVICE, dtype=dtype) ref_weight_slice = ref_model.weight[0:16, 0:32] 
quant_weight_slice = quant_model.weight[0:16, 0:32] @@ -678,14 +669,14 @@ def test_preprocess_scale_3d_reshape(self): device = "cpu" # Use CPU for basic functionality test # Test 1: PerTensor scale (scalar) - should reshape to (1, 1) - per_tensor_scale = torch.tensor(0.5, device=device) + per_tensor_scale = torch.tensor(0.5, device=_DEVICE) result = preprocess_scale(per_tensor_scale, (2, 4, 8)) expected_shape = (1, 1) self.assertEqual(result.shape, expected_shape) self.assertEqual(result.item(), 0.5) # Test 2: 1D scale tensor with one element - should reshape to (1, 1) - one_element_scale = torch.tensor([0.3], device=device) + one_element_scale = torch.tensor([0.3], device=_DEVICE) result = preprocess_scale(one_element_scale, (2, 4, 8)) expected_shape = (1, 1) self.assertEqual(result.shape, expected_shape) @@ -694,7 +685,7 @@ def test_preprocess_scale_3d_reshape(self): # Test 3: 3D scale tensor for per-row quantization - should flatten first N-1 dims # This is the key test for the 3D reshape fix scale_3d = torch.randn( - 2, 4, device=device + 2, 4, device=_DEVICE ) # Shape matches first 2 dims of (2, 4, 8) result = preprocess_scale(scale_3d, (2, 4, 8)) expected_shape = (8, 1) # Flattened (2*4, 1) @@ -705,14 +696,14 @@ def test_preprocess_scale_3d_reshape(self): self.assertTrue(torch.allclose(result, expected_values)) # Test 4: 2D scale tensor (already correct shape) - should just add last dimension - scale_2d = torch.randn(8, device=device) + scale_2d = torch.randn(8, device=_DEVICE) result = preprocess_scale(scale_2d, (8, 16)) expected_shape = (8, 1) self.assertEqual(result.shape, expected_shape) # Test 5: Edge case with higher dimensions (4D) scale_4d = torch.randn( - 2, 2, 2, device=device + 2, 2, 2, device=_DEVICE ) # Shape matches first 3 dims of (2, 2, 2, 8) result = preprocess_scale(scale_4d, (2, 2, 2, 8)) expected_shape = (8, 1) # Flattened (2*2*2, 1) diff --git a/test/dtypes/test_bitpacking.py b/test/dtypes/test_bitpacking.py index 0ed4462d5d..dd13ec852b 100644 --- a/test/dtypes/test_bitpacking.py +++ b/test/dtypes/test_bitpacking.py @@ -9,6 +9,11 @@ from torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() + + bit_widths = (1, 2, 3, 4, 5, 6, 7) dimensions = (0, -1, 1) @@ -30,17 +35,17 @@ def test_CPU(bit_width, dim): assert unpacked.allclose(test_tensor) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.parametrize("bit_width", bit_widths) @pytest.mark.parametrize("dim", dimensions) def test_GPU(bit_width, dim): - test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda() + test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(_DEVICE) packed = pack(test_tensor, bit_width, dim=dim) unpacked = unpack(packed, bit_width, dim=dim) assert unpacked.allclose(test_tensor) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.parametrize("bit_width", bit_widths) @pytest.mark.parametrize("dim", dimensions) @@ -48,22 +53,22 @@ def test_compile(bit_width, dim): torch._dynamo.config.specialize_int = True torch.compile(pack, fullgraph=True) torch.compile(unpack, fullgraph=True) - test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda() + test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(_DEVICE) packed = pack(test_tensor, 
bit_width, dim=dim) unpacked = unpack(packed, bit_width, dim=dim) assert unpacked.allclose(test_tensor) # these test cases are for the example pack walk through in the bitpacking.py file -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_pack_example(): test_tensor = torch.tensor( [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8 - ).cuda() + ).to(_DEVICE) shard_4, shard_2 = pack(test_tensor, 6) print(shard_4, shard_2) - assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4) - assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2) + assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).to(_DEVICE).allclose(shard_4) + assert torch.tensor([39, 146], dtype=torch.uint8).to(_DEVICE).allclose(shard_2) unpacked = unpack([shard_4, shard_2], 6) assert unpacked.allclose(test_tensor) diff --git a/test/dtypes/test_fbgemm_fp8.py b/test/dtypes/test_fbgemm_fp8.py index ea869a1c39..5d811eb9af 100644 --- a/test/dtypes/test_fbgemm_fp8.py +++ b/test/dtypes/test_fbgemm_fp8.py @@ -21,12 +21,14 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90, + auto_detect_device, ) +_DEVICE = auto_detect_device() + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") -@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") -@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") +@unittest.skipIf(_DEVICE == "cuda" and not is_sm_at_least_90(), "Need sm90+") class TestFbgemmFp8Tensor(TestCase): def setUp(self): self.config = FbgemmConfig( @@ -40,11 +42,10 @@ def setUp(self): output_dtype=torch.bfloat16, transpose_input=True, ) - self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] def test_linear(self): dtype = torch.bfloat16 - device = "cuda" + device = _DEVICE input = torch.randn(1, 128, dtype=dtype, device=device) linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) original = linear(input) @@ -54,7 +55,7 @@ def test_linear(self): def test_slice(self): dtype = torch.bfloat16 - device = "cuda" + device = _DEVICE dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) dummy1.weight = torch.nn.Parameter( @@ -123,7 +124,7 @@ def forward(self, x): return torch.bmm(x, self.weight) dtype = torch.bfloat16 - device = "cuda" + device = _DEVICE input = torch.randn(10, 32, 128, dtype=dtype, device=device) weight = torch.randn(10, 128, 256, dtype=dtype, device=device) m = M(weight).eval() @@ -135,7 +136,7 @@ def forward(self, x): self.assertTrue(compute_error(original, quantized) > 20) def test_to_device(self): - for device in self.GPU_DEVICES: + for device in [_DEVICE]: linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) quantize_(linear, self.config) linear.to(device) diff --git a/test/dtypes/test_fbgemm_int4.py b/test/dtypes/test_fbgemm_int4.py index eb1f059775..d74689057b 100644 --- a/test/dtypes/test_fbgemm_int4.py +++ b/test/dtypes/test_fbgemm_int4.py @@ -20,12 +20,13 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90, + auto_detect_device, ) +_DEVICE = auto_detect_device() @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") -@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") -@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") +@unittest.skipIf(_DEVICE == "cuda" and not is_sm_at_least_90(), "Need sm90+") class TestFbgemmInt4Tensor(TestCase): def setUp(self): self.config = FbgemmConfig(
@@ -40,11 +41,10 @@ def setUp(self): output_dtype=torch.bfloat16, block_size=[1, 1, 128], ) - self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] def test_linear(self): dtype = torch.bfloat16 - device = "cuda" + device = _DEVICE input = torch.randn(1, 128, dtype=dtype, device=device) linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) original = linear(input) @@ -54,7 +54,7 @@ def test_linear(self): def test_slice(self): dtype = torch.bfloat16 - device = "cuda" + device = _DEVICE dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) dummy1.weight = torch.nn.Parameter( @@ -92,9 +92,9 @@ def test_slice(self): assert compute_error(res, res_ref) > 15 def test_slice_and_copy_(self): - l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16) l.weight = torch.nn.Parameter( - torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + torch.zeros(1024, 1024, dtype=torch.bfloat16, device=_DEVICE) ) quantize_(l, self.config) param = l.weight @@ -108,7 +108,7 @@ def test_slice_and_copy_(self): orig_value = param.data.packed_weight[0][0].item() # dummy_l has random input (shouldn't be 0) - dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + dummy_l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16) quantize_(dummy_l, self.config) quantized = dummy_l.weight quantized = quantized.narrow(0, 0, 512) @@ -128,7 +128,7 @@ def forward(self, x): return torch.bmm(x, self.weight) dtype = torch.bfloat16 - device = "cuda" + device = _DEVICE input = torch.randn(10, 32, 128, dtype=dtype, device=device) weight = torch.randn(10, 128, 256, dtype=dtype, device=device) m = M(weight).eval() @@ -140,7 +140,7 @@ def forward(self, x): self.assertTrue(compute_error(original, quantized) > 18) def test_to_device(self): - for device in self.GPU_DEVICES: + for device in [_DEVICE]: linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) quantize_(linear, self.config) linear.to(device) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 237bc2bd92..93807670df 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -33,9 +33,9 @@ quantize_, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, get_available_devices -_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +_DEVICES = get_available_devices() _Floatx_DTYPES = [(3, 2), (2, 2)] @@ -87,7 +87,7 @@ def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device): ) torch.testing.assert_close(actual, expected) - @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("device", _DEVICES) @parametrize("ebits,mbits", _Floatx_DTYPES) def test_to_copy_device(self, ebits, mbits): from torchao.quantization.quant_primitives import ( @@ -101,12 +101,10 @@ def test_to_copy_device(self, ebits, mbits): _layout = FloatxTensorCoreLayout(ebits, mbits) floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain( x, scale, None, _layout - ).cuda() - assert floatx_tensor_impl.device.type == "cuda" - floatx_tensor_impl = floatx_tensor_impl.cpu() - assert floatx_tensor_impl.device.type == "cpu" + ).to(device) + assert floatx_tensor_impl.device.type == device - @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("device",
_DEVICES) @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+", @@ -116,9 +114,8 @@ def test_to_copy_device(self, ebits, mbits): @parametrize("dtype", [torch.half, torch.bfloat16]) @unittest.skipIf(is_fbcode(), reason="broken in fbcode") @skip_if_rocm("ROCm enablement in progress") - def test_fpx_weight_only(self, ebits, mbits, bias, dtype): + def test_fpx_weight_only(self, device, ebits, mbits, bias, dtype): N, OC, IC = 4, 256, 64 - device = "cuda" linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype) fpx_linear = copy.deepcopy(linear) diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 35c722365d..8223be5fe9 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -15,6 +15,7 @@ quantize_affine, ) from torchao.utils import ( + get_available_devices, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, ) @@ -34,7 +35,7 @@ dtypes = () group_sizes = [32, 64, 128] -devices = ["cpu", "cuda"] +devices = get_available_devices() @pytest.fixture(autouse=True) @@ -61,31 +62,29 @@ def __init__(self, scale, device): def forward(self, x): return self.net(x) - +@pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" ) -def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): +def test_uintx_quant_on_cpu_then_move_to_cuda(device, dtype, group_size): scale = 512 fp16_mod_on_cpu = Linear16(scale, "cpu") quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size)) test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu") output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu) - fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda") - test_input_on_cuda = test_input_on_cpu.to("cuda") - output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda) - assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), ( - "The output of the model on CPU and CUDA should be close" + fp16_mod_on_gpu = fp16_mod_on_cpu.to(device) + test_input_on_gpu = test_input_on_cpu.to(device) + output_on_gpu = fp16_mod_on_gpu(test_input_on_gpu) + assert torch.allclose(output_on_cpu, output_on_gpu.cpu(), atol=1.0e-3), ( + "The output of the model on CPU and GPU should be close" ) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" ) @@ -102,7 +101,6 @@ def test_uintx_weight_only_model_quant(dtype, group_size, device): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" ) @@ -139,41 +137,41 @@ def test_uintx_weight_only_quant(dtype, group_size, device): @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") +@pytest.mark.parametrize("device", devices) @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype 
requires torch 2.3+" ) -def test_uintx_target_dtype(dtype): +def test_uintx_target_dtype(dtype, device): from torchao.quantization.quant_api import uintx_weight_only - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device) # make sure it runs quantize_(linear, uintx_weight_only(dtype)) - linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) + linear(torch.randn(1, 128, dtype=torch.bfloat16, device=device)) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") +@pytest.mark.parametrize("device", devices) @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_5, reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+", ) -def test_uintx_target_dtype_compile(dtype): +def test_uintx_target_dtype_compile(dtype, device): from torchao.quantization.quant_api import uintx_weight_only - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device) # make sure it runs quantize_(linear, uintx_weight_only(dtype)) linear = torch.compile(linear) - linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) + linear(torch.randn(1, 128, dtype=torch.bfloat16, device=device)) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") +@pytest.mark.parametrize("device", devices) @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" ) -def test_uintx_model_size(dtype): +def test_uintx_model_size(dtype, device): from torchao.quantization.quant_api import uintx_weight_only from torchao.utils import get_model_size_in_bytes @@ -190,7 +188,7 @@ def test_uintx_model_size(dtype): torch.uint7: (7 / 8 + 1 / 16 + 1 / 32) / 2, } linear = torch.nn.Sequential( - torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda") + torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device=device) ) bf16_size = get_model_size_in_bytes(linear) # make sure it runs diff --git a/test/float8/test_base.py b/test/float8/test_base.py index c19478e02a..e0b1d4ab7e 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -55,6 +55,10 @@ from torchao.testing.training.test_utils import get_test_float8_linear_config from torchao.utils import is_MI300, is_ROCM +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() + random.seed(0) torch.manual_seed(0) @@ -236,11 +240,10 @@ def test_axiswise_reshape(self): (ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE), ], ) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") + @unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): - a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") - b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") + a = torch.randn(*a_shape, dtype=torch.bfloat16, device=_DEVICE) + b = torch.randn(64, 32, dtype=torch.bfloat16, device=_DEVICE) linear_mm_config = LinearMMConfig() @@ -269,7 +272,6 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): sqnr = compute_error(c_ref, c_fp8_compute) assert sqnr >= 25.0 - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def 
test_fp8_dtype( self, ): @@ -334,7 +336,6 @@ def _test_linear_impl( @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @pytest.mark.parametrize("use_ac", [False, True]) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_linear_from_config_params( self, x_shape, @@ -346,8 +347,8 @@ def test_linear_from_config_params( linear_bias: bool, use_ac: bool, ): - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) - m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) + x = torch.randn(*x_shape, device=_DEVICE, dtype=linear_dtype) + m_ref = nn.Linear(16, 32, bias=linear_bias, device=_DEVICE, dtype=linear_dtype) config = get_test_float8_linear_config( scaling_type_input, @@ -379,7 +380,6 @@ def test_linear_from_config_params( @pytest.mark.parametrize( "linear_dtype", [torch.bfloat16, torch.float16, torch.float32] ) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @unittest.skipIf( torch.cuda.is_available() and not is_sm_at_least_90(), "CUDA capability < 9.0" ) @@ -391,8 +391,8 @@ def test_linear_from_recipe( linear_dtype: torch.dtype, linear_bias: bool, ): - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) - m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) + x = torch.randn(*x_shape, device=_DEVICE, dtype=linear_dtype) + m_ref = nn.Linear(16, 32, bias=linear_bias, device=_DEVICE, dtype=linear_dtype) config = Float8LinearConfig.from_recipe_name(recipe_name) self._test_linear_impl( x, @@ -414,7 +414,6 @@ def test_linear_from_recipe( Float8LinearRecipeName.ROWWISE_WITH_GW_HP, ], ) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_autocast_outputs( self, emulate: bool, @@ -422,8 +421,8 @@ def test_autocast_outputs( recipe_name: Float8LinearRecipeName, ): m_ref = nn.Sequential( - nn.Linear(32, 32, device="cuda", dtype=linear_dtype), - nn.Linear(32, 32, device="cuda", dtype=linear_dtype), + nn.Linear(32, 32, device=_DEVICE, dtype=linear_dtype), + nn.Linear(32, 32, device=_DEVICE, dtype=linear_dtype), ) config = Float8LinearConfig.from_recipe_name(recipe_name) # work around config being frozen @@ -432,16 +431,16 @@ def test_autocast_outputs( m = convert_to_float8_training(copy.deepcopy(m_ref), config=config) # autocast off - x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) + x = torch.randn(16, 32, device=_DEVICE, dtype=linear_dtype) y = m(x) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on - with torch.autocast("cuda"): + with torch.autocast(_DEVICE): y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" - with torch.autocast("cuda", dtype=torch.bfloat16): + with torch.autocast(_DEVICE, dtype=torch.bfloat16): y = m(x) assert y.dtype == torch.bfloat16, ( f"y.dtype is {y.dtype}, expected {torch.bfloat16}" @@ -459,18 +458,18 @@ def test_repr(self): s = m.__repr__() assert "i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2" in s - @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available") + @unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_89(), "CUDA 8.9 not available") def test_inference_mode(self): - x = torch.randn(32, 32, device="cuda") - m = nn.Sequential(nn.Linear(32, 32)).cuda() + x = torch.randn(32, 32, device=_DEVICE) + m = nn.Sequential(nn.Linear(32, 32)).to(_DEVICE) m = convert_to_float8_training(m) with torch.inference_mode(mode=True): m(x) - 
@unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available") + @unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_89(), "CUDA arch 8.9 not available") def test_quantize(self): - x = torch.randn(32, 32, device="cuda") - m = nn.Sequential(nn.Linear(32, 32)).cuda() + x = torch.randn(32, 32, device=_DEVICE) + m = nn.Sequential(nn.Linear(32, 32)).to(_DEVICE) m = convert_to_float8_training(m) assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear" from torchao.quantization.quant_api import float8_weight_only, quantize_ @@ -485,7 +484,7 @@ def test_quantize(self): class TestScaledMM: @unittest.skipIf( - not is_sm_at_least_89(), + torch.cuda.is_available() and not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( @@ -498,8 +497,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): output_dtype = base_dtype compare_type = torch.float32 - a = torch.randn(16, 16, device="cuda", dtype=base_dtype) - b = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() + a = torch.randn(16, 16, device=_DEVICE, dtype=base_dtype) + b = torch.randn(32, 16, device=_DEVICE, dtype=base_dtype).t() a_scale = tensor_to_scale(a, input_dtype).float() b_scale = tensor_to_scale(b, input_dtype).float() @@ -530,10 +529,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): atol, rtol = 3e-3, 3e-3 torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available") + @unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_89(), "CUDA not available") def test_different_configs_error(self): - x_fp32 = torch.randn(16, 16, device="cuda") - x_scale = torch.tensor(1.0, device="cuda") + x_fp32 = torch.randn(16, 16, device=_DEVICE) + x_scale = torch.tensor(1.0, device=_DEVICE) fp8_dtype = e4m3_dtype linear_config_a = LinearMMConfig( ScaledMMConfig(False, True, False, False), @@ -566,7 +565,7 @@ def test_different_configs_error(self): a @ b @unittest.skipIf( - not is_sm_at_least_89(), + torch.cuda.is_available() and not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( @@ -578,8 +577,8 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): input_dtype = e4m3_dtype compare_type = torch.float32 - a = torch.randn(16, 41, device="cuda", dtype=base_dtype) - b = torch.randn(41, 128, device="cuda", dtype=base_dtype) + a = torch.randn(16, 41, device=_DEVICE, dtype=base_dtype) + b = torch.randn(41, 128, device=_DEVICE, dtype=base_dtype) a_scale = tensor_to_scale(a, input_dtype).float() b_scale = tensor_to_scale(b, input_dtype).float() @@ -657,7 +656,6 @@ class TestNumerics: torch.float8_e5m2fnuz, ], ) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_small_amax_float16(self, float8_dtype): # If we calculate scale naively with FP8_MAX_POS / amax, # the result may not be representable in fp16. 
Verify that @@ -676,7 +674,7 @@ def test_small_amax_float16(self, float8_dtype): FP16_MAX_POS = torch.finfo(torch.float16).max target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12) - x = torch.tensor([target_amax], dtype=torch.float16, device="cuda") + x = torch.tensor([target_amax], dtype=torch.float16, device=_DEVICE) scale = tensor_to_scale(x, float8_dtype) assert not torch.any(torch.isinf(scale)) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index a196d87430..d9c6400bf0 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -43,6 +43,9 @@ ) from torchao.testing.training.test_utils import get_test_float8_linear_config +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() def _test_compile_base( backend: str, @@ -55,9 +58,9 @@ def _test_compile_base( x_shape = (16, 16) linear_dtype = torch.bfloat16 - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_() + x = torch.randn(*x_shape, device=_DEVICE, dtype=linear_dtype).requires_grad_() x_ref = copy.deepcopy(x) - m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) + m_ref = nn.Linear(16, 32, bias=True, device=_DEVICE, dtype=linear_dtype) m_fp8 = Float8Linear.from_float( copy.deepcopy(m_ref), @@ -92,7 +95,6 @@ def _test_compile_base( ) @pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) -@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_eager_only( fullgraph, emulate: bool, @@ -128,7 +130,6 @@ def test_eager_only( [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) -@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( fullgraph, emulate: bool, @@ -164,7 +165,7 @@ def test_aot_eager( [ScalingType.DYNAMIC], ) @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), + torch.cuda.is_available() and not is_sm_at_least_89(), "CUDA with float8 support not available", ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @@ -203,7 +204,7 @@ def test_inductor_from_config_params( ], ) @unittest.skipIf( - not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available" + torch.cuda.is_available() and not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available" ) def test_inductor_from_recipe(recipe_name): torch._dynamo.reset() @@ -238,23 +239,23 @@ def forward(self, x): # TODO(future): figure out why the test below fails on CUDA capability 8.9 @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_90(), + torch.cuda.is_available() and not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available", ) def test_float8_with_graph_break_in_the_middle(self): """Test that having Float8TrainingTensor object at the boundary of a subgraph""" cnts = CompileCounterWithBackend("inductor") - mod = self.MockLinear(graph_break=True).cuda() + mod = self.MockLinear(graph_break=True).to(_DEVICE) compiled_mod = copy.deepcopy(mod) compiled_mod = torch.compile(compiled_mod, backend=cnts) - x = torch.randn(16, 16, device="cuda") + x = torch.randn(16, 16, device=_DEVICE) y_eager = mod(x) y_compiled = compiled_mod(x) self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!") torch.testing.assert_close(y_eager, y_compiled) @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), + torch.cuda.is_available() and not 
is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_float8_graph_input(self): @@ -264,8 +265,8 @@ def to_float(x): return x.to_original_precision() cnts = CompileCounterWithBackend("inductor") - mod = self.MockLinear(graph_break=False).cuda() - x = torch.randn(2, 2, device="cuda") + mod = self.MockLinear(graph_break=False).to(_DEVICE) + x = torch.randn(2, 2, device=_DEVICE) compiled_to_float = torch.compile(to_float, backend=cnts) y = mod(x) y2_eager = to_float(y) @@ -278,15 +279,15 @@ def to_float(x): torch.testing.assert_close(y2_eager, y2_compiled) @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), + torch.cuda.is_available() and not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_float8_graph_output(self): """Test that having Float8TrainingTensor object as a graph output works""" cnts = CompileCounterWithBackend("inductor") - mod = self.MockLinear(graph_break=False).cuda() + mod = self.MockLinear(graph_break=False).to(_DEVICE) compiled_mod = torch.compile(mod, backend=cnts) - x = torch.randn(16, 16, device="cuda") + x = torch.randn(16, 16, device=_DEVICE) y_compiled = compiled_mod(x) self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") @@ -325,7 +326,7 @@ def __exit__(self, *args): @unittest.skipIf( - not is_sm_at_least_89(), + torch.cuda.is_available() and not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( @@ -348,7 +349,7 @@ def test_dynamic_scale_numeric_parity( ): scaling_type_weight = ScalingType.DYNAMIC torch.manual_seed(42) - hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype) + hp_tensor1 = torch.randn(16, 16, device=_DEVICE, dtype=dtype) hp_tensor2 = hp_tensor1.detach().clone() float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index db02444109..d80e384cf1 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -35,6 +35,10 @@ from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.training.test_utils import get_test_float8_linear_config +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() + torch.manual_seed(0) @@ -94,7 +98,7 @@ def _test_impl(self, config: Float8LinearConfig) -> None: multiple_of=1024, ffn_dim_multiplier=1.3, ) - .cuda() + .to(_DEVICE) .to(data_dtype) ) @@ -115,8 +119,8 @@ def _test_impl(self, config: Float8LinearConfig) -> None: # logic of delayed scaling behaves as dynamic scaling # TODO(future PR): delete ^, since we deleted delayed scaling shape = (1, 8192, 4096) - data1 = torch.randn(*shape, device="cuda", dtype=data_dtype) - data2 = torch.randn(*shape, device="cuda", dtype=data_dtype) + data1 = torch.randn(*shape, device=_DEVICE, dtype=data_dtype) + data2 = torch.randn(*shape, device=_DEVICE, dtype=data_dtype) model_ref(data1).sum().backward() # zero out grads without stepping, since we just want to compare grads @@ -160,7 +164,7 @@ def _test_impl(self, config: Float8LinearConfig) -> None: [ScalingType.DYNAMIC], ) @pytest.mark.skipif( - not is_sm_at_least_89(), reason="requires SM89 compatible machine" + torch.cuda.is_available() and not is_sm_at_least_89(), reason="requires SM89 compatible machine" ) @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def
test_encoder_fw_bw_from_config_params( diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index a6990549a3..d484a7e629 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -17,12 +17,14 @@ from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, + auto_detect_device, ) cuda_available = torch.cuda.is_available() # Parameters -device = "cuda:0" +device = auto_detect_device() +print("Testing on ", device) compute_dtype = torch.bfloat16 group_size = 64 mapping_type = MappingType.ASYMMETRIC @@ -77,7 +79,6 @@ def _eval_hqq(dtype): return dequantize_error, dot_product_error -@unittest.skipIf(not cuda_available, "Need CUDA available") @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+") class TestHQQ(unittest.TestCase): def _test_hqq( diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index 92e670a95b..b6f632d2d1 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -21,6 +21,10 @@ import torch from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() +print("Testing on ", _DEVICE) # Test configs SHAPES = [ @@ -92,7 +96,7 @@ def test_mixed_mm( } M, N, K = shape - linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=_DEVICE) quant_config = BaseQuantizeConfig( quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False @@ -115,7 +119,7 @@ def test_mixed_mm( packed_w = pack_2xint4(W_q.T) if transposed: - x = torch.randn(M, N, dtype=dtype, device="cuda") + x = torch.randn(M, N, dtype=dtype, device=_DEVICE) hqq_out = x @ W_dq tt_out = triton_mixed_mm( @@ -130,7 +134,7 @@ def test_mixed_mm( ) else: - x = torch.randn(M, K, dtype=dtype, device="cuda") + x = torch.randn(M, K, dtype=dtype, device=_DEVICE) hqq_out = x @ W_dq.T tt_out = triton_mixed_mm( @@ -174,9 +178,9 @@ def _test_mixed_mm( quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False ) quant_config.update({"weight_quant_params": qcfg}) - W_q = torch.randint(0, int(2**4), size=(N, K), dtype=quant_dtype, device="cuda") + W_q = torch.randint(0, int(2**4), size=(N, K), dtype=quant_dtype, device=_DEVICE) - scales = torch.arange((N * K) // group_size, dtype=dtype, device="cuda")[:, None] + scales = torch.arange((N * K) // group_size, dtype=dtype, device=_DEVICE)[:, None] zeros = torch.zeros_like(scales) W_dq = ((W_q.reshape(-1, group_size) - zeros) * scales).reshape(N, K) scales = scales.reshape(N, -1) @@ -185,7 +189,7 @@ def _test_mixed_mm( packed_w = pack_2xint4(W_q.T) if transposed: - x = torch.randn(M, N, dtype=dtype, device="cuda") + x = torch.randn(M, N, dtype=dtype, device=_DEVICE) hqq_out = x @ W_dq tt_out = triton_mixed_mm( @@ -203,7 +207,7 @@ def _test_mixed_mm( ) else: - x = torch.randn(M, K, dtype=dtype, device="cuda") + x = torch.randn(M, K, dtype=dtype, device=_DEVICE) hqq_out = x @ W_dq.T tt_out = triton_mixed_mm( diff --git a/test/hqq/test_triton_qkv_fused.py b/test/hqq/test_triton_qkv_fused.py index 015c0a2f05..8a06b14add 100644 --- a/test/hqq/test_triton_qkv_fused.py +++ b/test/hqq/test_triton_qkv_fused.py @@ -20,6 +20,9 @@ import torch from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() torch.manual_seed(0) # N, K = shape @@ -54,7 +57,7 @@ def _arg_to_id(arg): def quantize_helper( - weight_shape, 
quant_config, dtype, device="cuda", quant_dtype=torch.uint8 + weight_shape, quant_config, dtype, device=_DEVICE, quant_dtype=torch.uint8 ): N, K = weight_shape linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device) @@ -120,7 +123,7 @@ def test_mixed_mm( transposed, kernel_type, seqlen=16, - device="cuda", + device=_DEVICE, quant_dtype=torch.uint8, ): """ diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index f7cd9833b6..ee507b15e5 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -10,6 +10,7 @@ import logging import os import unittest +import pytest from functools import partial import torch @@ -79,6 +80,7 @@ ) from torchao.testing.utils import skip_if_rocm from torchao.utils import ( + get_available_devices, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, @@ -106,7 +108,7 @@ torch.manual_seed(0) config.cache_size_limit = 100 -COMMON_DEVICES = ["cpu", "cuda"] +COMMON_DEVICES = get_available_devices() COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -392,14 +394,13 @@ def test_swap(self): y = m_copy(x) assert torch.allclose(y_ref, y) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @pytest.mark.parametrize("device", COMMON_DEVICES) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") - def test_weight_t_and_non_t_numerics_match(self): + def test_weight_t_and_non_t_numerics_match(self, device): # verify that numerics match whether weight is stored # in transposed format (for cuBLAS) vs non-transposed format # (for torch.compile) dtype = torch.half - device = "cuda" lin_ref = nn.Linear(32, 16, dtype=dtype, device=device) lin_eager_t = copy.deepcopy(lin_ref) lin_opt_t = copy.deepcopy(lin_eager_t) @@ -598,14 +599,14 @@ def test_per_token_linear_cuda(self): for dtype in (torch.float32, torch.float16, torch.bfloat16): self._test_per_token_linear_impl("cuda", dtype) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test__int_mm(self): + @pytest.mark.parametrize("device", COMMON_DEVICES) + def test__int_mm(self, device): # TODO(future): figure out what here needs to move to PT core, # if it's not already tested there m, k, n = 32, 32, 16 - x = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda") - w = torch.randint(-128, 127, (k, n), dtype=torch.int8, device="cuda") + x = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=device) + w = torch.randint(-128, 127, (k, n), dtype=torch.int8, device=device) y_ref = torch.matmul(x.float(), w.float()).to(torch.int32) y_raw = safe_int_mm(x, w) @@ -619,13 +620,13 @@ def test__int_mm(self): torch.testing.assert_close(y_ref, y_raw, atol=0, rtol=0) torch.testing.assert_close(y_ref, y_opt, atol=0, rtol=0) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test__int_mm_eager_and_torch_compile_numerics(self): + @pytest.mark.parametrize("device", COMMON_DEVICES) + def test__int_mm_eager_and_torch_compile_numerics(self, device): def __int_mm_ref(x, w): x = x.cpu().to(torch.int32) w = w.cpu().to(torch.int32) y = torch.matmul(x, w) - return y.cuda() + return y.to(device) shapes = ( # minimal test shape @@ -653,8 +654,8 @@ def wrap_torch_int_mm(x, w): wrap_torch_int_mm, mode="max-autotune" ) - x = torch.randint(-128, 127, x_shape, dtype=torch.int8, device="cuda") - w = torch.randint(-128, 127, w_shape, dtype=torch.int8, device="cuda") + x = torch.randint(-128, 127, x_shape, dtype=torch.int8, 
device=device) + w = torch.randint(-128, 127, w_shape, dtype=torch.int8, device=device) z_ref = __int_mm_ref(x, w) z_eager = wrap_torch_int_mm(x, w) @@ -754,6 +755,7 @@ def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype) test_dtype=dtype, ) + @pytest.mark.parametrize("device", COMMON_DEVICES) @run_supported_device_dtype def _test_lin_weight_subclass_impl( self, @@ -763,8 +765,6 @@ def _test_lin_weight_subclass_impl( test_dtype=torch.bfloat16, test_shape=(32, 64, 32), ): - if not "cuda" in test_device: - self.skipTest("test requires cuda") with torch.no_grad(): m, k, n = test_shape x = torch.randn(m, k, device=test_device, dtype=test_dtype) @@ -1240,15 +1240,8 @@ def test_weight_only_groupwise_embedding_quant(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_weight_only_quant_force_mixed_mm(self, device, dtype): undo_recommended_configs() - if device != "cuda": - self.skipTest( - f"weight_only_quant_force_mixed_mm can't be constructed on {device}" - ) - if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): - self.skipTest("test requires SM capability of at least (8, 0).") from torch._inductor import config mixed_mm_key, mixed_mm_val = ( @@ -1276,14 +1269,9 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): self.assertGreaterEqual(sqnr, 38) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_weight_only_quant_use_mixed_mm(self, device, dtype): undo_recommended_configs() - if device != "cuda": - self.skipTest( - f"weight_only_quant_force_mixed_mm can't be constructed on {device}" - ) - if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): + if dtype == torch.bfloat16 and torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") torch.manual_seed(0) from torch._inductor import config @@ -1414,17 +1402,17 @@ def test_save_load_int4woqtensors(self, device, dtype): class TorchCompileUnitTest(unittest.TestCase): - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_3, "fullgraph requires torch nightly." 
) - def test_fullgraph(self): - lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) + @pytest.mark.parametrize("device", COMMON_DEVICES) + def test_fullgraph(self, device): + lin_fp16 = nn.Linear(32, 16, device=device, dtype=torch.float16) lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( lin_fp16, alpha=0.25 ) - x0 = torch.randn(17, 1, 32, device="cuda", dtype=torch.float16) + x0 = torch.randn(17, 1, 32, device=device, dtype=torch.float16) # calibrate _ = lin_smooth(x0) @@ -1465,7 +1453,7 @@ def test_shape_logger(self): class SmoothquantIntegrationTest(unittest.TestCase): @torch.no_grad() - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @pytest.mark.parametrize("device", COMMON_DEVICES) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") def test_non_dynamically_quantizable_linear(self): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): @@ -1475,10 +1463,10 @@ def test_non_dynamically_quantizable_linear(self): torch.nn.modules.linear.NonDynamicallyQuantizableLinear(32, 32), torch.nn.ReLU(), ) - .to("cuda") + .to(device) .to(torch.bfloat16) ) - example_input = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16) + example_input = torch.randn(32, 32, device=device, dtype=torch.bfloat16) ref = model(example_input) swap_linear_with_smooth_fq_linear(model) model(ref) @@ -1561,12 +1549,11 @@ class TestAutoQuant(unittest.TestCase): ], ) ) + @pytest.mark.parametrize("device", COMMON_DEVICES) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_one_input(self, device, dtype, m, k, n): undo_recommended_configs() print("(m, k, n): ", (m, k, n)) - if device != "cuda" or not torch.cuda.is_available(): - self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): if dtype == torch.bfloat16: self.skipTest("bfloat16 requires sm80+") @@ -1658,8 +1645,6 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_mha(self, device, dtype): - if device != "cuda" or not torch.cuda.is_available(): - self.skipTest(f"autoquant currently does not support {device}") class MHAModel(torch.nn.Module): def __init__(self): @@ -1687,8 +1672,6 @@ def forward(self, x): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_manual(self, device, dtype): undo_recommended_configs() - if device != "cuda" or not torch.cuda.is_available(): - self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): if dtype == torch.bfloat16: self.skipTest("bfloat16 requires sm80+") @@ -1737,8 +1720,6 @@ def test_autoquant_manual(self, device, dtype): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): undo_recommended_configs() - if device != "cuda" or not torch.cuda.is_available(): - self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): if dtype == torch.bfloat16: self.skipTest("bfloat16 requires sm80+") @@ -1784,8 +1765,6 @@ def forward(self, x, y): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_double_access(self, device, 
dtype, m, k, n): undo_recommended_configs() - if device != "cuda" or not torch.cuda.is_available(): - self.skipTest(f"autoquant currently does not support {device}") if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): if dtype == torch.bfloat16: self.skipTest("bfloat16 requires sm80+") @@ -1811,8 +1790,7 @@ def forward(self, x): assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight) model(x_in) - @parameterized.expand(list(itertools.product(["cuda"], COMMON_DTYPES))) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand(COMMON_DEVICE_DTYPE) def test_autoquant_min_sqnr(self, device, dtype): m, k, n = 128, 128, 128 example_input = torch.randn(m, k, device=device, dtype=dtype) @@ -1833,12 +1811,12 @@ def test_autoquant_min_sqnr(self, device, dtype): # setting min_sqnr for individual linear to be 60 allows us to achieve >= 50 final sqnr self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @pytest.mark.parametrize("device", COMMON_DEVICES) @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+." ) - def test_autoquant_hp_float(self): - device = "cuda" + def test_autoquant_hp_float(self, device): + device = device dtype = torch.float32 m, k, n = 128, 128, 128 example_input = torch.randn(m, k, device=device, dtype=dtype) @@ -1866,7 +1844,7 @@ def test_autoquant_hp_float(self): self.assertGreater(compute_error(out, ref), 40) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @pytest.mark.parametrize("device", COMMON_DEVICES) @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." ) @@ -1904,7 +1882,7 @@ def test_autoquant_int4wo(self, device, dtype): self.assertGreater(compute_error(ref, out), 20) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") + @unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need cuda arch greater than SM90") @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." 
) @@ -1948,7 +1926,7 @@ def test_autoquant_float8(self, device, dtype): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") -@unittest.skipIf(not torch.cuda.is_available(), "requires cuda") +@pytest.mark.parametrize("device", COMMON_DEVICES) @unittest.skip( "AOTI tests are failing right now, repro by commenting out the skip and run:" "python test/integration/test_integration.py -k TestAOTI.test_aoti_06" @@ -1958,7 +1936,7 @@ class TestAOTI(unittest.TestCase): list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), ) def test_aoti(self, api, test_device, test_dtype): - if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda": + if api is change_linear_weights_to_int8_dqtensors and test_device == device: self.skipTest( f"{api} in {test_device} is not support for aoti compilation yet" ) @@ -2011,7 +1989,6 @@ def forward(self, x): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") -@unittest.skipIf(not torch.cuda.is_available(), "requires cuda") class TestExport(unittest.TestCase): @parameterized.expand( list( @@ -2094,8 +2071,8 @@ def __init__(self): def forward(self, x): return self.linear(x) - model = SimpleNetwork().eval().cuda() - inp = torch.randn(2, 32).cuda() + model = SimpleNetwork().eval().to(device) + inp = torch.randn(2, 32).to(device) config = Float8DynamicActivationFloat8WeightConfig() quantize_(model, config) @@ -2114,8 +2091,6 @@ class TestUtils(unittest.TestCase): def test_get_model_size_aqt(self, api, test_device, test_dtype): if test_dtype != torch.bfloat16: self.skipTest(f"{api} in {test_dtype} is not supported yet") - if test_device != "cuda" or not torch.cuda.is_available(): - self.skipTest(f"{api} currently does not support {test_device}") k, n = 1024, 1024 model = ( torch.nn.Sequential( @@ -2165,6 +2140,9 @@ def run_benchmark_model(self, device): def test_benchmark_model_cuda(self): assert self.run_benchmark_model("cuda") is not None + def test_benchmark_model_xpu(self): + assert self.run_benchmark_model("xpu") is not None + def test_benchmark_model_cpu(self): assert self.run_benchmark_model("cpu") is not None diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py index 996dfa5aa2..029e72a8df 100644 --- a/test/kernel/test_autotuner.py +++ b/test/kernel/test_autotuner.py @@ -9,14 +9,23 @@ import logging import os import unittest +import itertools import torch from parameterized import parameterized -from torchao.utils import is_sm_at_least_90 +from torchao.utils import ( + is_sm_at_least_90, + get_available_devices +) logging.basicConfig(level=logging.INFO) +COMMON_DEVICES = get_available_devices() + +COMMON_DTYPES = [torch.float16, torch.bfloat16] + +COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() class TestQuantFlow(unittest.TestCase): def setUp(self): @@ -25,15 +34,7 @@ def setUp(self): def tearDown(self): del os.environ["TORCHAO_AUTOTUNER_ENABLE"] - @parameterized.expand( - [ - ("cuda", torch.bfloat16), - # TODO: ("cpu", torch.bfloat16), - ("cuda", torch.float16), - # TODO: ("cpu", torch.float16), - ] - ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @parameterized.expand(COMMON_DEVICE_DTYPE) def test_int_mm(self, device, dtype): from torchao.kernel import intmm @@ -49,13 +50,8 @@ def test_int_mm(self, device, dtype): assert out32_2.dtype == out32_1.dtype torch.testing.assert_allclose(out32_1, out32_2) - @parameterized.expand( - [ - ("cuda", torch.bfloat16), - ("cuda", torch.float16), - ] - ) - 
@unittest.skipIf(not is_sm_at_least_90(), "Needs H100") + @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Needs H100") def test_int_mm_float8(self, device, dtype): from torchao.kernel import intmm @@ -68,14 +64,7 @@ def test_int_mm_float8(self, device, dtype): out32_1 = intmm.safe_int_mm(x_float8, w_float8) assert out32_1.dtype == torch.int32 - @parameterized.expand( - [ - ("cuda", torch.bfloat16), - ("cpu", torch.bfloat16), - ("cuda", torch.float16), - ("cpu", torch.float16), - ] - ) + @parameterized.expand(COMMON_DEVICE_DTYPE) def test_int_scaled_mm(self, device, dtype): if device == "cuda" and not torch.cuda.is_available(): self.skipTest(f"{device} not available") diff --git a/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py b/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py index e8e855232c..754e20e15e 100644 --- a/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py +++ b/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py @@ -24,7 +24,13 @@ torch_blockwise_scale_weight_quant, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import is_sm_at_least_90 +from torchao.utils import ( + is_sm_at_least_90, + auto_detect_device, +) + +_DEVICE = [auto_detect_device()] +print("Testing on ", _DEVICE) BLOCKWISE_SIZE_MNK = [ (128, 128, 128), @@ -37,19 +43,18 @@ (67, 6656, 1408), ] - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.parametrize("device", _DEVICE) +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.skipif( version.parse(triton.__version__) < version.parse("3.3.0"), reason="Triton version < 3.3.0, test skipped", ) @pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype): +def test_blockwise_fp8_gemm_1x128_128x128(device, M, N, K, dtype): # Simulate output = input @ weight.T - A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") - B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") + A = torch.randn(M, K, dtype=torch.bfloat16, device=device) + B = torch.randn(N, K, dtype=torch.bfloat16, device=device) C = A @ B.T A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=dtype) B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype) @@ -60,19 +65,18 @@ def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype): min_sqnr = 28.0 assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}" - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.parametrize("device", _DEVICE) +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.skipif( version.parse(triton.__version__) < version.parse("3.3.0"), reason="Triton version < 3.3.0, test skipped", ) @pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -def test_blockwise_fp8_gemm_1x128_128x1(M, N, K, dtype): +def test_blockwise_fp8_gemm_1x128_128x1(device, M, N, K, dtype): # Simulate grad_weight = grad_output_t @ input - A = torch.randn(K, M, dtype=torch.bfloat16, device="cuda") - B = 
torch.randn(K, N, dtype=torch.bfloat16, device="cuda") + A = torch.randn(K, M, dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) C = A.T @ B A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype) B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=dtype) @@ -86,12 +90,11 @@ def test_blockwise_fp8_gemm_1x128_128x1(M, N, K, dtype): min_sqnr = 28.0 assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}" - +@pytest.mark.parametrize("device", _DEVICE) @skip_if_rocm("ROCm not supported") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.parametrize("block_size", [128, 256]) -def test_triton_quantize_fp8_act_quant_lhs(block_size): - device = "cuda" +def test_triton_quantize_fp8_act_quant_lhs(device, block_size): M, K = 4096, 1024 x = torch.randn(M, K, device=device) @@ -133,12 +136,11 @@ def test_triton_quantize_fp8_act_quant_lhs(block_size): msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", ) - +@pytest.mark.parametrize("device", _DEVICE) @skip_if_rocm("ROCm not supported") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.parametrize("block_size", [128, 256]) -def test_triton_quantize_fp8_act_quant_rhs(block_size: int): - device = "cuda" +def test_triton_quantize_fp8_act_quant_rhs(device, block_size: int): M, K = 4096, 1024 x = torch.randn(M, K, device=device) @@ -180,13 +182,12 @@ def test_triton_quantize_fp8_act_quant_rhs(block_size: int): msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", ) - +@pytest.mark.parametrize("device", _DEVICE) @skip_if_rocm("ROCm not supported") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.parametrize("block_size", [128, 256]) @pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)]) -def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int): - device = "cuda" +def test_triton_quantize_fp8_act_quant_transposed_lhs(device, M, K, block_size: int): x = torch.randn(M, K, device=device) # Set one scaling block to 0s, so if nan guards/EPS are not applied, the @@ -229,13 +230,12 @@ def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int): msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", ) - +@pytest.mark.parametrize("device", _DEVICE) @skip_if_rocm("ROCm not supported") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.parametrize("block_size", [128, 256]) @pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)]) -def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int): - device = "cuda" +def test_triton_quantize_fp8_weight_quant_rhs(device, M, K, block_size: int): x = torch.randn(M, K, device=device) # Set one scaling block to 0s, so if nan guards/EPS are not applied, the @@ -275,12 +275,11 @@ def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int): msg=f"Scales 
differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", ) - +@pytest.mark.parametrize("device", _DEVICE) @skip_if_rocm("ROCm not supported") -@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.skipif(torch.cuda.is_available() and not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") @pytest.mark.parametrize("block_size", [128, 256]) -def test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int): - device = "cuda" +def test_triton_quantize_fp8_weight_quant_transposed_rhs(device, block_size: int): M = 512 K = 2048 x = torch.randn(M, K, device=device) diff --git a/test/prototype/blockwise_fp8_training/test_blockwise_linear.py b/test/prototype/blockwise_fp8_training/test_blockwise_linear.py index fdb1ad42f5..b44afaa331 100644 --- a/test/prototype/blockwise_fp8_training/test_blockwise_linear.py +++ b/test/prototype/blockwise_fp8_training/test_blockwise_linear.py @@ -12,7 +12,7 @@ from torchao.utils import is_sm_at_least_90 triton = pytest.importorskip("triton", reason="Triton required to run this test") -if not is_sm_at_least_90(): +if torch.cuda.is_available() and not is_sm_at_least_90(): pytest.skip("This test requires SM90 or higher", allow_module_level=True) @@ -21,8 +21,10 @@ torch.random.manual_seed(0) +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("in_features", [4096]) @pytest.mark.parametrize("out_features", [128256]) @pytest.mark.parametrize("batch_size", [1, 8]) @@ -40,12 +42,12 @@ def test_blockwise_quant_linear_fwd_bwd( in_features=in_features, out_features=out_features, bias=False, - ).cuda() + ).to(_DEVICE) layer_test = Float8BlockwiseLinear.from_float(copy.deepcopy(layer_ref)) # Create input tensor - x_test = torch.randn(batch_size, 256, in_features).cuda().requires_grad_(True) + x_test = torch.randn(batch_size, 256, in_features).to(_DEVICE).requires_grad_(True) x_ref = x_test.clone().detach().requires_grad_(True) # Forward pass diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index ed68e8fa23..45657c6ddb 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -9,16 +9,6 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -# We need to skip before doing any imports which would use triton, since -# triton won't be available on CPU builds and torch < 2.5 -if not ( - TORCH_VERSION_AT_LEAST_2_5 - and torch.cuda.is_available() - and torch.cuda.get_device_capability()[0] >= 9 -): - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( triton_fp8_col_major_jagged_colwise_scales, triton_fp8_row_major_jagged_rowwise_scales, @@ -28,15 +18,19 @@ _to_2d_jagged_float8_tensor_colwise, _to_2d_jagged_float8_tensor_rowwise, ) -from torchao.testing.utils import skip_if_rocm +from torchao.testing.utils import ( + skip_if_rocm, +) +from torchao.utils import auto_detect_device +_DEVICE = auto_detect_device() @skip_if_rocm("ROCm enablement in progress") @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False]) def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool): # tests case where rowwise scales are computed for multiple distinct subtensors, # with end boundary of each group is determine by their end column indexes (offsets). 
- device = "cuda" + device = _DEVICE m, k, n_groups = 256, 256, 4 x = torch.randn(m, k * n_groups, device=device) colwise_offs = torch.arange(k, k * n_groups + 1, k, device=device) @@ -64,7 +58,7 @@ def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool): def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: bool): # tests case where colwise scales are computed for multiple distinct subtensors, # with end boundary of each group is determine by their end row indexes (offsets). - device = "cuda" + device = _DEVICE m, k, n_groups = 256, 256, 4 x = torch.randn(m * n_groups, k, device=device).t().contiguous().t() rowwise_offs = torch.arange(m, m * n_groups + 1, m, device=device) diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py index 426e88b534..0e059ae8a3 100644 --- a/test/prototype/moe_training/test_scaled_grouped_mm.py +++ b/test/prototype/moe_training/test_scaled_grouped_mm.py @@ -14,8 +14,6 @@ # triton won't be available on CPU builds and torch < 2.5 if not ( TORCH_VERSION_AT_LEAST_2_7 - and torch.cuda.is_available() - and torch.cuda.get_device_capability()[0] >= 9 ): pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -40,12 +38,14 @@ ) from torchao.prototype.mx_formats.mx_tensor import to_mx from torchao.testing.utils import skip_if_rocm +from torchao.utils import auto_detect_device +_DEVICE = auto_detect_device() @skip_if_rocm("ROCm not supported") def test_valid_scaled_grouped_mm_2d_3d(): out_dtype = torch.bfloat16 - device = "cuda" + device = _DEVICE m, n, k, n_groups = 16, 32, 16, 4 a = torch.randn( m * n_groups, @@ -61,7 +61,7 @@ def test_valid_scaled_grouped_mm_2d_3d(): device=device, dtype=torch.bfloat16, ) - offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) + offs = torch.arange(m, n_groups * m + 1, m, device=_DEVICE, dtype=torch.int32) # b must be transposed and in column major format. b_t = b.contiguous().transpose(-2, -1).requires_grad_(True) @@ -109,7 +109,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k): if n % 16 == 0 and k % 16 == 0: return out_dtype = torch.bfloat16 - device = "cuda" + device = _DEVICE n_groups = 4 a = torch.randn( m * n_groups, @@ -131,7 +131,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k): b_t = b.transpose(-2, -1) b_t = b_t.transpose(-2, -1).contiguous().transpose(-2, -1) - offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) + offs = torch.arange(m, n_groups * m + 1, m, device=_DEVICE, dtype=torch.int32) # Compute output. 
with pytest.raises(AssertionError): @@ -229,8 +229,8 @@ def compute_reference_forward( @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)]) @pytest.mark.parametrize("num_experts", (1, 8, 16)) def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts): - x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") - w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda") + x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE) + w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device=_DEVICE) offs = generate_jagged_offs(num_experts, M) x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone() @@ -263,9 +263,9 @@ def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts): def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts): # Simluate 2d-2d grouped gemm grad_weight = grad_output_t @ x block_size = 32 - grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + grad_out = torch.randn(M, N, dtype=torch.bfloat16, device=_DEVICE) grad_out_t = grad_out.t().contiguous() - x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + x = torch.randn(M, N, dtype=torch.bfloat16, device=_DEVICE) offs = generate_jagged_offs(num_experts, M, multiple_of=block_size) x_ref, grad_out_t_ref, offs_ref = x.clone(), grad_out_t.clone(), offs.clone() @@ -312,9 +312,9 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts): ) block_size = 32 - x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True) + x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE, requires_grad=True) w_t = torch.randn( - num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True + num_experts, K, N, dtype=torch.bfloat16, device=_DEVICE, requires_grad=True ) offs = generate_jagged_offs(num_experts, M, multiple_of=block_size) x_ref, w_t_ref, offs_ref = ( diff --git a/test/prototype/moe_training/test_tp.py b/test/prototype/moe_training/test_tp.py index 46ba544791..420dfa9a59 100644 --- a/test/prototype/moe_training/test_tp.py +++ b/test/prototype/moe_training/test_tp.py @@ -34,12 +34,16 @@ allow_module_level=True, ) +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() # this feature requires CUDA and SM89+ -if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9): - pytest.skip( - "CUDA not available or compute capability < 8.9", allow_module_level=True - ) +if torch.cuda.is_available(): + if torch.cuda.get_device_capability() < (8, 9): + pytest.skip( + "CUDA not available or compute capability < 8.9", allow_module_level=True + ) from torchao.float8.float8_utils import compute_error from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig @@ -72,7 +76,6 @@ ], ) def test_moe_float8_training_tp(target_fqns: list[str]): - assert torch.cuda.is_available() # setup distributed for tp mesh = setup_distributed() @@ -85,10 +88,10 @@ def test_moe_float8_training_tp(target_fqns: list[str]): vocab_size=1024, ) init_std = 0.02 - device = torch.device("cuda") + device = torch.device(_DEVICE) # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + ref_model = MoE(model_args).to(torch.bfloat16).to(_DEVICE) torch.manual_seed(1) ref_model.init_weights(init_std, device) @@ -184,10 +187,13 @@ def setup_distributed(): rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) dist.init_process_group("nccl", rank=rank, world_size=world_size) - device_mesh = init_device_mesh("cuda", (world_size,)) + device_mesh = 
init_device_mesh(_DEVICE, (world_size,)) # seed must be the same in all processes torch.manual_seed(1) - torch.cuda.set_device(rank) + if _DEVICE == "cuda": + torch.cuda.set_device(rank) + elif _DEVICE == "xpu": + torch.xpu.set_device(rank) return device_mesh diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py index 9a68542d88..a2cba69002 100644 --- a/test/prototype/moe_training/test_training.py +++ b/test/prototype/moe_training/test_training.py @@ -6,10 +6,11 @@ from torch.nn import functional as F # this feature requires CUDA and SM89+ -if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9): - pytest.skip( - "CUDA not available or compute capability < 8.9", allow_module_level=True - ) +if torch.cuda.is_available(): + if torch.cuda.get_device_capability() < (8, 9): + pytest.skip( + "CUDA not available or compute capability < 8.9", allow_module_level=True + ) from torchao.float8.float8_utils import compute_error from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig @@ -17,6 +18,10 @@ from .testing_utils import _validate_model_conversion +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() + # this test requires torchtitan try: from torchtitan.experiments.llama4.model.args import TransformerModelArgs @@ -42,10 +47,10 @@ def test_moe_float8_training(target_fqns: list[str], compile: bool): dim=256, ) init_std = 0.02 - device = torch.device("cuda") + device = torch.device(_DEVICE) # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + ref_model = MoE(model_args).to(torch.bfloat16).to(_DEVICE) torch.manual_seed(42) ref_model.init_weights(init_std, device) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index 6b0aab129c..e4875cc797 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -45,11 +45,16 @@ from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx from torchao.prototype.mx_formats.utils import to_blocked from torchao.utils import ( + get_available_devices, TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, is_sm_at_least_100, ) +from torchao.utils import get_available_devices + + +_DEVICES = get_available_devices() torch.manual_seed(0) if not TORCH_VERSION_AT_LEAST_2_8: @@ -327,7 +332,6 @@ def test_fp4_pack_unpack(): assert torch.all(orig_vals_dq == orig_vals) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.skipif(is_sm_at_least_100(), reason="broken on CUDA capability 10.0") def test_fp4_triton_unscaled_cast(): @@ -337,7 +341,6 @@ def test_fp4_triton_unscaled_cast(): assert torch.all(torch.eq(f32_ref, f32_triton)) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.skipif(is_sm_at_least_100(), reason="broken on CUDA capability 10.0") def test_fp4_triton_scaled_cast(): @@ -403,18 +406,7 @@ def test_fp6_values(dtype_name): torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0) -@pytest.mark.parametrize( - "device", - [ - "cpu", - pytest.param( - "cuda", - marks=pytest.mark.skipif( - not torch.cuda.is_available(), reason="CUDA not available" - ), - ), - ], -) +@pytest.mark.parametrize("device", _DEVICES) @pytest.mark.parametrize( "f32_val,f6_e3m2_enc", [ @@ -433,12 +425,10 
@@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device): assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("device", _DEVICES) @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -def test_fp6_e2m3_pack_unpack(): - orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to( - "cuda" - ) +def test_fp6_e2m3_pack_unpack(device): + orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(device) orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals) orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked) assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4) @@ -448,12 +438,10 @@ def test_fp6_e2m3_pack_unpack(): assert torch.all(orig_vals_f6_packed_unpacked == orig_vals) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("device", _DEVICES) @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -def test_fp6_e3m2_pack_unpack(): - orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to( - "cuda" - ) +def test_fp6_e3m2_pack_unpack(device): + orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(device) orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals) orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked) assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4) @@ -471,14 +459,14 @@ def test_fp6_e3m2_pack_unpack(): @pytest.mark.parametrize("M", (256, 2048)) @pytest.mark.parametrize("K", (256, 2048)) def test_triton_mxfp8_dim1_randn(M, K): - x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32) x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32) torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0) torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("device", _DEVICES) @pytest.mark.parametrize( "shape", [ @@ -492,8 +480,8 @@ def test_triton_mxfp8_dim1_randn(M, K): (128, 1), ], ) -def test_rearrange(shape): - scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8) +def test_rearrange(device, shape): + scales = torch.randint(256, size=shape, device=device, dtype=torch.uint8) eager = to_blocked(scales, False) triton = to_blocked(scales, True) torch.testing.assert_close(eager, triton, atol=0, rtol=0) @@ -519,7 +507,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode): # Use disinct incrementing values from 0 to M*K-1 to make debugging easier. 
x = ( - torch.arange(0, M * K, dtype=input_dtype, device="cuda") + torch.arange(0, M * K, dtype=input_dtype, device=device) .reshape(M, K) .contiguous() ) @@ -557,7 +545,7 @@ def test_cuda_mx_dim0_not_supported(): M, K = 64, 64 block_size = 32 x = ( - torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda") + torch.arange(0, M * K, dtype=torch.bfloat16, device=device) .reshape(M, K) .contiguous() ) @@ -580,7 +568,7 @@ def test_cuda_mx_dim1_invalid_block_size(): M, K = 64, 64 x = ( - torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda") + torch.arange(0, M * K, dtype=torch.bfloat16, device=device) .reshape(M, K) .contiguous() ) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 67ac9e7a61..fbd5103f72 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -30,6 +30,10 @@ is_sm_at_least_100, ) +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() + torch.manual_seed(2) if not TORCH_VERSION_AT_LEAST_2_8: @@ -66,7 +70,6 @@ def run_around_tests(): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", elem_dtypes) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)]) @@ -99,7 +102,7 @@ def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_cho grad_shape[-1] = 256 m = nn.Sequential( - nn.Linear(256, 256, bias=bias, device="cuda", dtype=torch.bfloat16), + nn.Linear(256, 256, bias=bias, device=_DEVICE, dtype=torch.bfloat16), ) m_mx = copy.deepcopy(m) config = MXLinearConfig( @@ -112,10 +115,10 @@ def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_cho quantize_(m_mx, config) x_ref = torch.randn( - *input_shape, device="cuda", dtype=torch.bfloat16 + *input_shape, device=_DEVICE, dtype=torch.bfloat16 ).requires_grad_() x = copy.deepcopy(x_ref) - g = torch.randn(*grad_shape, device="cuda") + g = torch.randn(*grad_shape, device=_DEVICE) y_ref = m(x_ref) y_mx = m_mx(x) @@ -139,9 +142,8 @@ def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_cho assert x_g_sqnr >= 8.0 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( - not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8" + torch.cuda.is_available() and not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8" ) @pytest.mark.parametrize( "recipe_name", @@ -154,11 +156,11 @@ def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_cho def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn): M, K, N = mkn - x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_() + x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE).requires_grad_() x_copy = copy.deepcopy(x) - g = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) + g = torch.randn(M, N, device=_DEVICE, dtype=torch.bfloat16) m_emulated = nn.Sequential( - nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16), + nn.Linear(K, N, bias=False, device=_DEVICE, dtype=torch.bfloat16), ) m_real = copy.deepcopy(m_emulated) @@ -188,26 +190,24 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn): # TODO(future): enable compile support -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_activation_checkpointing(): input_shape = (16, 4) 
grad_shape = (16, 8) elem_dtype = torch.float8_e4m3fn m = nn.Sequential( - nn.Linear(4, 8, bias=True, device="cuda"), - nn.Linear(8, 8, bias=True, device="cuda"), + nn.Linear(4, 8, bias=True, device=_DEVICE), + nn.Linear(8, 8, bias=True, device=_DEVICE), ) config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype) quantize_(m, config=config) - x = torch.randn(*input_shape, device="cuda").requires_grad_() - g = torch.randn(*grad_shape, device="cuda") + x = torch.randn(*input_shape, device=_DEVICE).requires_grad_() + g = torch.randn(*grad_shape, device=_DEVICE) y = torch.utils.checkpoint.checkpoint(m, x, use_reentrant=False) y.backward(g) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize( "recipe_name", @@ -265,7 +265,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice): input_shape = (M, K) grad_shape = (M, N) m_mx = nn.Sequential( - nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype), + nn.Linear(K, N, bias=bias, device=_DEVICE, dtype=hp_dtype), ) config = MXLinearConfig.from_recipe_name(recipe_name) config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice @@ -274,9 +274,9 @@ def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice): m_mx_c = copy.deepcopy(m_mx) m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor") - x_ref = torch.randn(*input_shape, device="cuda", dtype=hp_dtype).requires_grad_() + x_ref = torch.randn(*input_shape, device=_DEVICE, dtype=hp_dtype).requires_grad_() x = copy.deepcopy(x_ref) - g = torch.randn(*grad_shape, device="cuda", dtype=hp_dtype) + g = torch.randn(*grad_shape, device=_DEVICE, dtype=hp_dtype) y_ref = m_mx(x_ref) y = m_mx_c(x) diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py index 46380cfb55..ad31390bd0 100644 --- a/test/prototype/mx_formats/test_mx_mm.py +++ b/test/prototype/mx_formats/test_mx_mm.py @@ -13,17 +13,20 @@ from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.prototype.mx_formats.utils import to_blocked from torchao.utils import ( + auto_detect_device, TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100, ) +_DEVICE = auto_detect_device() + if not TORCH_VERSION_AT_LEAST_2_8: pytest.skip("Unsupported PyTorch version", allow_module_level=True) def run_matrix_test(M: int, K: int, N: int, format) -> float: dtype = torch.bfloat16 - device = torch.device("cuda") + device = torch.device(_DEVICE) a = torch.rand((M, K), dtype=dtype, device=device) b = torch.rand((N, K), dtype=dtype, device=device) @@ -57,9 +60,8 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float: return compute_error(out_hp, out).item() -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( - not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8" + torch.cuda.is_available() and not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8" ) @pytest.mark.parametrize( "size", diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 6fe91a379f..e92bab3e2e 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -25,11 +25,14 @@ ) from torchao.quantization.utils import compute_error from torchao.utils import ( + auto_detect_device, TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, is_sm_at_least_100, ) +_DEVICE = 
auto_detect_device() + torch.manual_seed(2) if not TORCH_VERSION_AT_LEAST_2_8: @@ -79,35 +82,31 @@ def assert_sqnr_gt_threshold(orig, new, threshold): assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_hello_world(elem_dtype): - data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16) + data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16) block_size = 4 _test_mx(data, elem_dtype, block_size) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode]) @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_realistic_numerics(elem_dtype, scale_calculation_mode): - data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + data = torch.randn(128, 128, device=_DEVICE, dtype=torch.bfloat16) block_size = 32 _test_mx(data, elem_dtype, block_size, scale_calculation_mode) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_all_zeros(elem_dtype): - data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16) + data = torch.zeros(4, 4, device=_DEVICE, dtype=torch.bfloat16) block_size = 4 _test_mx(data, elem_dtype, block_size) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_some_zeros(elem_dtype): - data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16) + data = torch.randn(4, 4, device=_DEVICE, dtype=torch.bfloat16) data[0, :] = 0.0 data[:, 2] = 0.0 block_size = 4 @@ -116,7 +115,6 @@ def test_some_zeros(elem_dtype): # TODO(future PR): fix and reenable this test @pytest.mark.skip(reason="does not pass on B200 yet") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_to_mx_rceil(): # nan # fmt: off @@ -329,7 +327,6 @@ def test_to_mx_rceil(): torch.testing.assert_close(data_mx._data, ground_truth_fp8) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_exponent_nan_in(elem_dtype): """ @@ -337,7 +334,7 @@ def test_exponent_nan_in(elem_dtype): value is set to is NaN """ tensor_hp = torch.tensor( - [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16 + [float("nan"), 1, 2, 3, 4, 5, 6, 7], device=_DEVICE, dtype=torch.bfloat16 ) block_size = 4 tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size) @@ -345,7 +342,6 @@ def test_exponent_nan_in(elem_dtype): assert not torch.any(torch.isnan(tensor_mx._scale_e8m0[1:])) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("pack_fp6", [False, True]) def test_exponent_nan_out(elem_dtype, pack_fp6): @@ -356,25 +352,25 @@ def test_exponent_nan_out(elem_dtype, pack_fp6): pytest.skip("invalid configuration") scale_e8m0 = torch.tensor( - [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda" + [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device=_DEVICE ) block_size = 4 if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): data_bits = torch.tensor( - [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda" + [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, 
device=_DEVICE ) # noqa: E501 elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2): data_bits = torch.tensor( - [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda" + [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=_DEVICE ) # noqa: E501 if pack_fp6: data_bits = data_bits.reshape(-1, block_size) data_bits = pack_uint6(data_bits) elif elem_dtype == torch.float4_e2m1fn_x2: data_bits = torch.tensor( - [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda" + [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=_DEVICE ) # noqa: E501 data_bits = pack_uint4(data_bits) else: @@ -396,7 +392,6 @@ def test_exponent_nan_out(elem_dtype, pack_fp6): assert not torch.any(torch.isnan(tensor_hp.flatten()[4:])) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_ranks(elem_dtype): """ @@ -405,11 +400,11 @@ def test_ranks(elem_dtype): B = 4 shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4)) for s in shapes: - tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16) + tensor_hp = torch.randn(*s, device=_DEVICE, dtype=torch.bfloat16) _test_mx(tensor_hp, elem_dtype, B) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("B", [1, 4, 32]) def test_block_sizes(elem_dtype, B): @@ -420,11 +415,10 @@ def test_block_sizes(elem_dtype, B): pytest.skip("unsupported configuration") elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]: pytest.skip("unsupported configuration") - tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16) + tensor_hp = torch.randn(B, device=_DEVICE, dtype=torch.bfloat16) _test_mx(tensor_hp, elem_dtype, B) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("fp4_triton", [False, True]) def test_transpose(elem_dtype, fp4_triton): @@ -436,7 +430,7 @@ def test_transpose(elem_dtype, fp4_triton): M, K = 128, 256 block_size = 32 - tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + tensor_hp = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16) tensor_mx = MXTensor.to_mx( tensor_hp, elem_dtype, @@ -452,20 +446,18 @@ def test_transpose(elem_dtype, fp4_triton): torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_view(elem_dtype): - x = torch.randn(1, 2, 4, device="cuda") + x = torch.randn(1, 2, 4, device=_DEVICE) block_size = 4 x_mx = MXTensor.to_mx(x, elem_dtype, block_size) x_mx_2 = x_mx.view(2, 4) # noqa: F841 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]) @pytest.mark.parametrize("pack_fp6", [False, True]) def test_fp6_packing(elem_dtype, pack_fp6): - x = torch.randn(1, 2, 4, device="cuda") + x = torch.randn(1, 2, 4, device=_DEVICE) block_size = 4 x_mx = MXTensor.to_mx(x, elem_dtype, block_size, pack_fp6=pack_fp6) if pack_fp6: @@ -476,7 +468,6 @@ def test_fp6_packing(elem_dtype, pack_fp6): assert x_mx._data.shape == expected_packed_shape -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("hp_dtype", 
[torch.float32, torch.bfloat16]) @pytest.mark.parametrize("all_zeros", [False, True]) @@ -491,9 +482,9 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): shape = 4, 8 if not all_zeros: - x = torch.randn(*shape, dtype=hp_dtype, device="cuda") + x = torch.randn(*shape, dtype=hp_dtype, device=_DEVICE) else: - x = torch.zeros(*shape, dtype=hp_dtype, device="cuda") + x = torch.zeros(*shape, dtype=hp_dtype, device=_DEVICE) block_size = 4 to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True) @@ -532,7 +523,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not is_sm_at_least_89(), reason="float8 in triton requires CUDA capability 8.9 or greater", @@ -544,14 +534,13 @@ def test_to_mx_inductor_single_kernel(): """ # TODO(future PR): add fp4 and fp6 here # TODO(#1773): add swizzled scale format here - x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda") + x = torch.randn(2048, 2048, dtype=torch.bfloat16, device=_DEVICE) block_size = 32 to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True) out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size) FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not is_sm_at_least_89(), reason="float8 in triton requires CUDA capability 8.9 or greater", @@ -568,7 +557,7 @@ def test_cast_to_float8_e4m3fn_saturation_behavior(): -1 * max_val, ], dtype=torch.bfloat16, - device="cuda", + device=_DEVICE, ) # create example data outside the representable range @@ -578,7 +567,7 @@ -1 * (max_val * 2), ], dtype=torch.bfloat16, - device="cuda", + device=_DEVICE, ) # verify that in eager mode PyTorch casting to float8 is unsaturated @@ -612,7 +601,6 @@ def to_f8(x): (torch.bfloat16, (64, 128), True), ], ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" ) @@ -622,7 +610,7 @@ def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale): per_tensor_amax_to_scale, ) - x = torch.randn(shape, dtype=dtype, device="cuda") + x = torch.randn(shape, dtype=dtype, device=_DEVICE) if use_per_tensor_scale: tensor_amax = torch.max(torch.abs(x)) scale = per_tensor_amax_to_scale(tensor_amax) @@ -680,7 +668,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold): ], ) @pytest.mark.parametrize( - "use_triton_kernel", [False, True] if torch.cuda.is_available() else [False] + "use_triton_kernel", [False, True] if torch.cuda.is_available() or torch.xpu.is_available() else [False] ) @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" ) @@ -689,7 +677,7 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool): from torchao.prototype.mx_formats.utils import from_blocked, to_blocked rows, cols = shape - device = "cuda" if torch.cuda.is_available() else "cpu" + device = _DEVICE original = torch.randint(0, 255, (rows, cols), device=device, dtype=torch.uint8) @@ -718,7 +706,6 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool): @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" ) -@pytest.mark.skipif(not 
torch.cuda.is_available(), reason="CUDA not available") def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape): """ Test that NVFP4Tensor can be constructed with swizzled scales and @@ -727,7 +714,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape): from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor M, K = shape - data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16) tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales) assert tensor._is_swizzled_scales == is_swizzled_scales @@ -753,7 +740,6 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape): pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"), ], ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+" ) @@ -772,7 +758,7 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec): # For column slicing, need multiples of 64 columns for alignment M, K = 128, 4096 - data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16) tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) assert tensor._is_swizzled_scales == True @@ -848,7 +834,6 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec): ), ], ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+" ) @@ -859,7 +844,7 @@ def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_er from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor M, K = 256, 4096 - data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16) tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) with pytest.raises(RuntimeError, match=expected_error): @@ -869,7 +854,6 @@ def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_er _ = tensor[:, slice_spec] -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+" ) @@ -880,7 +864,7 @@ def test_nvfp4_swizzled_scales_view_semantics(): from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor M, K = 256, 4096 - data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16) tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) # Test row slicing (should maintain views) @@ -896,7 +880,6 @@ def test_nvfp4_swizzled_scales_view_semantics(): assert full_width_slice._data.data_ptr() == tensor._data.data_ptr() -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+" ) @@ -907,7 +890,7 @@ def test_nvfp4_swizzled_scales_serialization(): from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor M, K = 32, 64 - data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16) # Create tensor with swizzled scales original_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) @@ -938,7 +921,6 @@ def test_nvfp4_swizzled_scales_serialization(): torch.testing.assert_close(original_dq, 
reconstructed_dq, atol=1e-6, rtol=1e-6) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+" ) @@ -949,7 +931,7 @@ def test_nvfp4_swizzled_scales_get_scales_method(): from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor M, K = 32, 64 - data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16) # Create tensors with both storage methods regular_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=False) @@ -966,7 +948,6 @@ def test_nvfp4_swizzled_scales_get_scales_method(): assert swizzled_scales.shape == expected_shape -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize( "M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}" ) @@ -988,7 +969,7 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype): ) torch.manual_seed(42) - x = torch.randn(M, N, dtype=dtype, device="cuda") + x = torch.randn(M, N, dtype=dtype, device=_DEVICE) per_tensor_scale = None if use_per_tensor_scale: diff --git a/test/prototype/test_autoround.py b/test/prototype/test_autoround.py index 483704a28c..24d9628501 100644 --- a/test/prototype/test_autoround.py +++ b/test/prototype/test_autoround.py @@ -25,9 +25,12 @@ prepare_model_for_applying_auto_round_, ) from torchao.prototype.autoround.multi_tensor import MultiTensor -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + get_available_devices +) -_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +_AVAILABLE_DEVICES = get_available_devices() # Copied from https://github.com/pytorch/ao/pull/721 diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 5538fa513d..6a2616317f 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -18,8 +18,10 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_6, _is_fbgemm_genai_gpu_available, + auto_detect_device, ) +_DEVICE = auto_detect_device() class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, n=256, k=128): @@ -29,7 +31,7 @@ def __init__(self, m=512, n=256, k=128): self.linear3 = torch.nn.Linear(k, 64, bias=False) def example_inputs( - self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda" + self, batch_size, sequence_length=10, dtype=torch.bfloat16, device=_DEVICE ): return [ torch.randn( @@ -45,7 +47,6 @@ def forward(self, x): return x -@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") @unittest.skipIf( not _is_fbgemm_genai_gpu_available(), reason="need to install fbgemm_gpu_genai package", @@ -69,7 +70,7 @@ def test_awq_config(self): AWQConfig(base_config, step="not_supported") def test_awq_functionality(self): - device = "cuda" + device = _DEVICE dataset_size = 100 l1, l2, l3 = 512, 256, 128 original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs @@ -118,7 +119,7 @@ def test_awq_functionality(self): assert loss_awq < loss_base def test_awq_loading(self): - device = "cuda" + device = _DEVICE dataset_size = 100 l1, l2, l3 = 512, 256, 128 original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs @@ -178,7 +179,7 @@ def test_awq_loading_vllm(self): There is also a slicing op that is ommitted here, overall e2e is tested in tests in vllm repo """ - device = "cuda" + device = _DEVICE dataset_size = 100 l1, l2, l3 = 512, 256, 
128 original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs diff --git a/test/prototype/test_blockwise_triton.py b/test/prototype/test_blockwise_triton.py index 1c79ed9b23..9d54702936 100644 --- a/test/prototype/test_blockwise_triton.py +++ b/test/prototype/test_blockwise_triton.py @@ -9,6 +9,10 @@ from packaging import version +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() + triton = pytest.importorskip("triton", reason="Triton required to run this test") from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( @@ -29,7 +33,6 @@ ] -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK) @pytest.mark.parametrize( "dtype", @@ -38,7 +41,7 @@ else [torch.float8_e5m2], ) def test_blockwise_quant_dequant(_, N, K, dtype): - x = torch.randn(N, K).cuda() + x = torch.randn(N, K).to(_DEVICE) qx, s = fp8_blockwise_weight_quant(x, dtype=dtype) x_reconstructed = fp8_blockwise_weight_dequant(qx, s) error = torch.norm(x - x_reconstructed) / torch.norm(x) @@ -47,7 +50,6 @@ def test_blockwise_quant_dequant(_, N, K, dtype): assert error < 0.1, "Quant-Dequant error is too high" -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( version.parse(triton.__version__) < version.parse("3.3.0"), reason="Triton version < 3.3.0, test skipped", @@ -60,8 +62,8 @@ def test_blockwise_quant_dequant(_, N, K, dtype): else [torch.float8_e5m2], ) def test_blockwise_fp8_gemm(M, N, K, dtype): - A = torch.randn(M, K).cuda() - B = torch.randn(N, K).cuda() + A = torch.randn(M, K).to(_DEVICE) + B = torch.randn(N, K).to(_DEVICE) C = A @ B.T A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype) B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype) diff --git a/test/prototype/test_codebook_coreml.py b/test/prototype/test_codebook_coreml.py index 0c16de8969..ba33b965cc 100644 --- a/test/prototype/test_codebook_coreml.py +++ b/test/prototype/test_codebook_coreml.py @@ -76,7 +76,6 @@ def test_quantize_api(self): ) assert type(m[0].weight) == CodebookQuantizedTensor - @skip_if_no_cuda() @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "requires 2.6+.") def test_export(self): m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32) diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index 6ceeb0d795..45d86e35a4 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -48,7 +48,9 @@ check_cpu_version, ) -_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() def split_param_groups(model): diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 264c70abb6..bde635676b 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -40,7 +40,9 @@ ) from torchao.quantization.quant_api import quantize_ -_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +from torchao.utils import get_available_devices + +_DEVICES = get_available_devices() def _reset(): @@ -184,12 +186,11 @@ def test_int8_weight_only_training(self, compile, device): ], ) @parametrize("module_swap", [False, True]) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_int8_mixed_precision_training(self, compile, config, module_swap): _reset() bsize = 64 embed_dim = 64 - device = "cuda" + device = 
_DEVICE linear = nn.Linear(embed_dim, embed_dim, device=device) linear_int8mp = copy.deepcopy(linear) @@ -221,7 +222,6 @@ def snr(ref, actual): @pytest.mark.skip("Flaky on CI") @parametrize("compile", [False, True]) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_bitnet_training(self, compile): # reference implementation # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf @@ -246,7 +246,7 @@ def forward(self, x): _reset() bsize = 4 embed_dim = 32 - device = "cuda" + device = _DEVICE # only use 1 matmul shape to reduce triton autotune time model_ref = nn.Sequential( @@ -343,7 +343,7 @@ def _run_subtest(self, args): dropout_p=0, ) torch.manual_seed(42) - base_model = Transformer(model_args).cuda() + base_model = Transformer(model_args).to(_DEVICE) fsdp_model = copy.deepcopy(base_model) quantize_(base_model.layers, quantize_fn) @@ -363,7 +363,7 @@ def _run_subtest(self, args): torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): - inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE) fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) fsdp_loss = fsdp_model(inp).sum() fsdp_loss.backward() @@ -394,7 +394,7 @@ def test_precompute_bitnet_scale(self): precompute_bitnet_scale_for_fsdp, ) - model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda() + model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(_DEVICE) model_fsdp = copy.deepcopy(model) quantize_(model_fsdp, bitnet_training()) fully_shard(model_fsdp) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index a5265f7b1f..b608ce0639 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -23,8 +23,11 @@ ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, + get_available_devices, ) +devices = get_available_devices() + if torch.version.hip is not None: pytest.skip("Skipping the test in ROCm", allow_module_level=True) @@ -56,9 +59,6 @@ def forward(self, x): bias_list = [True, False] alpha_list = [None, 0.5, 0.75] quant_mode_list = ["static", "dynamic"] -devices = ["cpu"] -if torch.cuda.is_available(): - devices.append("cuda") idtypes = (torch.float, torch.bfloat16, torch.half) if TORCH_VERSION_AT_LEAST_2_5: @@ -71,7 +71,6 @@ def forward(self, x): @pytest.mark.parametrize("quant_mode", quant_mode_list) @pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("idtype", idtypes) -@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it") def test_compute(bias, alpha, quant_mode, device, idtype): class Linear(torch.nn.Module): def __init__(self, bias: bool): diff --git a/test/prototype/test_spinquant.py b/test/prototype/test_spinquant.py index 03f0c34e20..72d89b2ad3 100644 --- a/test/prototype/test_spinquant.py +++ b/test/prototype/test_spinquant.py @@ -9,6 +9,7 @@ from torchao._models.llama.model import Transformer from torchao.prototype.spinquant import apply_spinquant +from torchao.utils import get_available_devices def _init_model(name="7B", device="cpu", precision=torch.bfloat16): model = Transformer.from_name(name) @@ -16,7 +17,7 @@ def _init_model(name="7B", device="cpu", precision=torch.bfloat16): return model.eval() -_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +_AVAILABLE_DEVICES = get_available_devices() @pytest.mark.parametrize("device", _AVAILABLE_DEVICES) 
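The hunks above all apply the same substitution: hard-coded "cuda" strings and per-test CUDA skip markers are replaced by the device helpers from torchao.utils. As a rough sketch of how that pattern reads in a standalone test module (the helper imports mirror this diff; the test names, shapes, and asserts are illustrative only, not taken from the PR):

# Illustrative sketch only -- not part of the patch.
import pytest
import torch

from torchao.utils import auto_detect_device, get_available_devices

_DEVICE = auto_detect_device()                 # preferred accelerator, falling back to "cpu"
_AVAILABLE_DEVICES = get_available_devices()   # every backend present on this machine


@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
def test_runs_on_every_available_backend(device):
    # replaces per-test "CUDA not available" skips: cpu always runs, cuda/xpu run when present
    x = torch.randn(8, 8, device=device)
    assert (x @ x).device.type == torch.device(device).type


def test_uses_preferred_accelerator():
    # replaces hard-coded device="cuda" so one test body covers both CUDA and XPU runners
    x = torch.randn(4, 4, device=_DEVICE)
    assert x.device.type in ("cuda", "xpu", "cpu")

Either helper keeps CPU-only runners green while letting the CUDA and XPU runners cover the accelerated paths.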
diff --git a/test/prototype/test_structured_sparsifier.py b/test/prototype/test_structured_sparsifier.py index 58b34fcae6..5d638ca68c 100644 --- a/test/prototype/test_structured_sparsifier.py +++ b/test/prototype/test_structured_sparsifier.py @@ -38,15 +38,14 @@ SaliencyPruner, ) +from torchao.utils import get_available_devices + + logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) -DEVICES = { - torch.device("cpu"), - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), -} - +DEVICES = get_available_devices() class SimplePruner(BaseStructuredSparsifier): def update_mask(self, module, tensor_name, **kwargs): diff --git a/test/quantization/test_gptq.py b/test/quantization/test_gptq.py index 98760f8cf6..6b275e8a8a 100644 --- a/test/quantization/test_gptq.py +++ b/test/quantization/test_gptq.py @@ -2,13 +2,16 @@ from pathlib import Path import torch -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import ( + TestCase, +) from torchao._models.llama.model import ( ModelArgs, Transformer, prepare_inputs_for_model, ) +from torchao.utils import auto_detect_device from torchao._models.llama.tokenizer import get_tokenizer from torchao.quantization import Int4WeightOnlyConfig, quantize_ from torchao.quantization.utils import compute_error @@ -18,10 +21,10 @@ torch.manual_seed(0) +_DEVICE = auto_detect_device() class TestGPTQ(TestCase): @unittest.skip("skipping until we get checkpoints for gpt-fast") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_gptq_quantizer_int4_weight_only(self): from torchao._models._eval import ( LMEvalInputRecorder, @@ -30,7 +33,6 @@ def test_gptq_quantizer_int4_weight_only(self): from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer precision = torch.bfloat16 - device = "cuda" checkpoint_path = Path( "../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth" ) @@ -80,7 +82,7 @@ def test_gptq_quantizer_int4_weight_only(self): model = quantizer.quantize(model, *inputs).cuda() model.reset_caches() - with torch.device("cuda"): + with torch.device(_DEVICE): model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size) limit = 1 @@ -89,7 +91,7 @@ def test_gptq_quantizer_int4_weight_only(self): tokenizer, model.config.block_size, prepare_inputs_for_model, - device, + _DEVICE, ).run_eval( ["wikitext"], limit, @@ -102,7 +104,6 @@ def test_gptq_quantizer_int4_weight_only(self): class TestMultiTensorFlow(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_add_tensors(self): from torchao.quantization.GPTQ import MultiTensor @@ -115,7 +116,6 @@ def test_multitensor_add_tensors(self): self.assertTrue(torch.equal(mt.values[1], tensor2)) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_pad_unpad(self): from torchao.quantization.GPTQ import MultiTensor @@ -127,7 +127,6 @@ def test_multitensor_pad_unpad(self): self.assertEqual(mt.count, 1) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_inplace_operation(self): from torchao.quantization.GPTQ import MultiTensor @@ -138,7 +137,6 @@ def test_multitensor_inplace_operation(self): class 
TestMultiTensorInputRecorder(TestCase): - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_input_recorder(self): from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder @@ -159,7 +157,6 @@ def test_multitensor_input_recorder(self): self.assertTrue(isinstance(MT_input[2][2], MultiTensor)) self.assertEqual(MT_input[3], torch.float) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_gptq_with_input_recorder(self): from torchao.quantization.GPTQ import ( Int4WeightOnlyGPTQQuantizer, @@ -170,7 +167,7 @@ def test_gptq_with_input_recorder(self): config = ModelArgs(n_layer=2) - with torch.device("cuda"): + with torch.device(_DEVICE): model = Transformer(config) model.setup_caches(max_batch_size=2, max_seq_length=100) idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32) @@ -191,7 +188,11 @@ def test_gptq_with_input_recorder(self): args = input_recorder.get_recorded_inputs() - quantizer = Int4WeightOnlyGPTQQuantizer() + if _DEVICE == "xpu": + from torchao.dtypes import Int4XPULayout + quantizer = Int4WeightOnlyGPTQQuantizer(device=torch.device("xpu"), layout=Int4XPULayout()) + else: + quantizer = Int4WeightOnlyGPTQQuantizer() quantizer.quantize(model, *args) diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index 425b881dba..0385bb0925 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -31,8 +31,11 @@ TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_sm_at_least_90, + auto_detect_device, ) +_DEVICE = auto_detect_device() + if torch.version.hip is not None: pytest.skip( "ROCm support for MoE quantization is under development", @@ -52,7 +55,7 @@ def _test_impl_moe_quant( base_class=AffineQuantizedTensor, tensor_impl_class=None, dtype=torch.bfloat16, - device="cuda", + device=_DEVICE, fullgraph=False, ): """ @@ -114,8 +117,6 @@ def _test_impl_moe_quant( ] ) def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): - if not torch.cuda.is_available(): - self.skipTest("Need CUDA available") if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") @@ -138,10 +139,6 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): ] ) def test_int4wo_base(self, name, num_tokens, fullgraph): - if not torch.cuda.is_available(): - self.skipTest("Need CUDA available") - if not is_sm_at_least_90(): - self.skipTest("Requires CUDA capability >= 9.0") if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index c83f64022b..3a5539f710 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -84,11 +84,13 @@ TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6, + auto_detect_device, ) # TODO: put this in a common test utils file -_CUDA_IS_AVAILABLE = torch.cuda.is_available() +_GPU_IS_AVAILABLE = True if torch.cuda.is_available() or torch.xpu.is_available() else False +_DEVICE = auto_detect_device() class Sub(torch.nn.Module): def __init__(self): @@ -330,7 +332,7 @@ def _set_ptq_weight( group_size, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to("cuda"), + q_weight.to(_DEVICE), qat_linear.inner_k_tiles, ) ptq_linear.weight = q_weight @@ -601,13 +603,13 @@ def _assert_close_4w(self, val, ref): print(mean_err) self.assertTrue(mean_err < 0.05) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + 
@unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available") def test_qat_4w_primitives(self): n_bit = 4 group_size = 32 inner_k_tiles = 8 scales_precision = torch.bfloat16 - device = torch.device("cuda") + device = torch.device(_DEVICE) dtype = torch.bfloat16 torch.manual_seed(self.SEED) x = torch.randn(100, 256, dtype=dtype, device=device) @@ -655,13 +657,13 @@ def test_qat_4w_primitives(self): @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available") def test_qat_4w_linear(self): from torchao.quantization.GPTQ import WeightOnlyInt4Linear from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear group_size = 128 - device = torch.device("cuda") + device = torch.device(_DEVICE) dtype = torch.bfloat16 torch.manual_seed(self.SEED) qat_linear = Int4WeightOnlyQATLinear( @@ -702,14 +704,14 @@ def test_qat_4w_quantizer_gradients(self): @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available") def test_qat_4w_quantizer(self): from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer from torchao.quantization.qat import Int4WeightOnlyQATQuantizer group_size = 32 inner_k_tiles = 8 - device = torch.device("cuda") + device = torch.device(_DEVICE) dtype = torch.bfloat16 torch.manual_seed(self.SEED) m = M().to(device).to(dtype) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b9d99e7ac7..f867bbf1db 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -74,8 +74,11 @@ is_sm_at_least_89, is_sm_at_least_90, unwrap_tensor_subclass, + auto_detect_device, ) +_DEVICE = auto_detect_device() + try: import gemlite # noqa: F401 @@ -301,7 +304,7 @@ def api(model): m2.load_state_dict(state_dict) m2 = m2.to(device="cuda") - example_inputs = map(lambda x: x.cuda(), example_inputs) + example_inputs = map(lambda x: x.to(_DEVICE), example_inputs) res = m2(*example_inputs) # TODO: figure out why ROCm has a larger error @@ -339,12 +342,13 @@ def test_8da4w_quantizer_linear_bias(self): m(*example_inputs) @unittest.skip("skipping until we get checkpoints for gpt-fast") + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") def test_quantizer_int4_weight_only(self): from torchao._models._eval import TransformerEvalWrapper from torchao.quantization.linear_quant_modules import Int4WeightOnlyQuantizer precision = torch.bfloat16 - device = "cuda" + device = _DEVICE checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") model = Transformer.from_name(checkpoint_path.parent.name) checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) @@ -361,7 +365,7 @@ def test_quantizer_int4_weight_only(self): quantizer = Int4WeightOnlyQuantizer( groupsize, ) - model = quantizer.quantize(model).cuda() + model = quantizer.quantize(model).to(_DEVICE) result = TransformerEvalWrapper( model, tokenizer, @@ -377,11 +381,12 @@ def test_quantizer_int4_weight_only(self): ) @unittest.skip("skipping until we get checkpoints for gpt-fast") + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") def test_eval_wrapper(self): from torchao._models._eval import 
TransformerEvalWrapper

         precision = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
         model = Transformer.from_name(checkpoint_path.parent.name)
         checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
@@ -410,11 +415,12 @@ def test_eval_wrapper(self):

     # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_eval_wrapper_llama3(self):
         from torchao._models._eval import TransformerEvalWrapper

         precision = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         checkpoint_path = Path(
             ".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth"
         )
@@ -609,11 +615,15 @@ def test_int8wo_quantized_model_to_device(self):
             self.assertEqual(cuda_res.cpu(), ref)

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_int4wo_quantized_model_to_device(self):
         # TODO: change initial model to "cpu"
-        devices = ["cuda", "cuda:0"]
+        if _DEVICE == "cuda":
+            devices = ["cuda", "cuda:0"]
+        elif _DEVICE == "xpu":
+            devices = ["xpu", "xpu:0"]
+
         for device in devices:
             m = ToyLinearModel().eval().to(torch.bfloat16).to(device)
             example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device)
@@ -627,10 +637,10 @@ def test_int4wo_quantized_model_to_device(self):
             self.assertEqual(cuda_res.cpu(), ref)

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_quantized_tensor_subclass_save_load_map_location(self):
-        m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda")
-        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
+        m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device=_DEVICE)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device=_DEVICE)
         quantize_(m, int8_weight_only())
         ref = m(*example_inputs)
@@ -643,32 +653,51 @@ def test_quantized_tensor_subclass_save_load_map_location(self):
         m_copy = ToyLinearModel().eval()
         m_copy.load_state_dict(state_dict, assign=True)
-        m_copy.to(dtype=torch.bfloat16, device="cuda")
+        m_copy.to(dtype=torch.bfloat16, device=_DEVICE)

         res = m_copy(*example_inputs)
         self.assertEqual(res, ref)

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_quantized_model_streaming(self):
-        def reset_memory():
-            gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+
+        def get_max_memory_allocated(device):
+            if device == "cuda":
+                return torch.cuda.max_memory_allocated(device)
+            elif device == "xpu":
+                return torch.xpu.max_memory_allocated(device)
+            elif device == "cpu":
+                return 0
+            else:
+                raise ValueError(f"Unsupported device type: {device}")

-        reset_memory()
+        def reset_memory(device):
+            gc.collect()
+            if device == "cuda":
+                torch.cuda.empty_cache()
+                # device is a plain string here, so reset for the current device
+                torch.cuda.reset_peak_memory_stats()
+            elif device == "xpu":
+                torch.xpu.empty_cache()
+                torch.xpu.reset_peak_memory_stats()
+            elif device == 
"cpu": + pass + else: + raise ValueError(f"Unsupported device type: {device}") + + reset_memory(_DEVICE) m = ToyLinearModel() - quantize_(m.to(device="cuda"), int8_weight_only()) - memory_baseline = torch.cuda.max_memory_allocated() + quantize_(m.to(device=_DEVICE), int8_weight_only()) + memory_baseline = get_max_memory_allocated(_DEVICE) del m - reset_memory() + reset_memory(_DEVICE) m = ToyLinearModel() - quantize_(m, int8_weight_only(), device="cuda") - memory_streaming = torch.cuda.max_memory_allocated() + quantize_(m, int8_weight_only(), device=_DEVICE) + memory_streaming = get_max_memory_allocated(_DEVICE) for param in m.parameters(): - assert param.is_cuda + assert getattr(param, f'is_{_DEVICE}') self.assertLess(memory_streaming, memory_baseline) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @@ -699,7 +727,7 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): # TODO(#1690): move to new config names @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") @common_utils.parametrize( "config", [ @@ -744,17 +772,17 @@ def test_workflow_e2e_numerics(self, config): # scale has to be moved to cuda here because the parametrization init # code happens before gating for cuda availability if isinstance(config, float8_static_activation_float8_weight): - config.scale = config.scale.to("cuda") + config.scale = config.scale.to(_DEVICE) dtype = torch.bfloat16 if isinstance(config, gemlite_uintx_weight_only): dtype = torch.float16 # set up inputs - x = torch.randn(128, 128, device="cuda", dtype=dtype) + x = torch.randn(128, 128, device=_DEVICE, dtype=dtype) # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 # is that expected? 
- m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype) + m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).to(_DEVICE).to(dtype) m_q = copy.deepcopy(m_ref) # quantize @@ -767,13 +795,13 @@ def test_workflow_e2e_numerics(self, config): sqnr = compute_error(y_ref, y_q) assert sqnr >= 16.5, f"SQNR {sqnr} is too low" - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") def test_module_fqn_to_config_default(self): config1 = Int4WeightOnlyConfig(group_size=32) config2 = Int8WeightOnlyConfig() config = ModuleFqnToConfig({"_default": config1, "linear2": config2}) - model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) - example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16) + model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16) + example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16) quantize_(model, config) model(*example_inputs) assert isinstance(model.linear1.weight, AffineQuantizedTensor) @@ -781,13 +809,13 @@ def test_module_fqn_to_config_default(self): assert isinstance(model.linear2.weight, AffineQuantizedTensor) assert isinstance(model.linear2.weight._layout, PlainLayout) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") def test_module_fqn_to_config_module_name(self): config1 = Int4WeightOnlyConfig(group_size=32) config2 = Int8WeightOnlyConfig() config = ModuleFqnToConfig({"linear1": config1, "linear2": config2}) - model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) - example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16) + model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16) + example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16) quantize_(model, config) model(*example_inputs) assert isinstance(model.linear1.weight, AffineQuantizedTensor) @@ -827,25 +855,25 @@ def test_module_fqn_to_config_embedding_linear(self): assert isinstance(model.emb.weight._layout, QDQLayout) assert isinstance(model.linear.weight, LinearActivationQuantizedTensor) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") def test_module_fqn_to_config_skip(self): config1 = Int4WeightOnlyConfig(group_size=32) config = ModuleFqnToConfig({"_default": config1, "linear2": None}) - model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) - example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16) + model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16) + example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16) quantize_(model, config) model(*example_inputs) assert isinstance(model.linear1.weight, AffineQuantizedTensor) assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout) assert not isinstance(model.linear2.weight, AffineQuantizedTensor) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") def test_int4wo_cuda_serialization(self): config = Int4WeightOnlyConfig(group_size=32) - model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) + model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16) # quantize in cuda quantize_(model, config) - example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16) + example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16) 
model(*example_inputs) with tempfile.NamedTemporaryFile() as ckpt: # save checkpoint in cuda @@ -854,7 +882,7 @@ def test_int4wo_cuda_serialization(self): # This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253 sd = torch.load(ckpt.name, weights_only=False, map_location="cpu") for k, v in sd.items(): - sd[k] = v.to("cuda") + sd[k] = v.to(_DEVICE) # load state_dict in cuda model.load_state_dict(sd, assign=True) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 12027243a8..eea90a5661 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -36,11 +36,14 @@ check_cpu_version, check_xpu_version, is_fbcode, + auto_detect_device, ) _SEED = 1234 torch.manual_seed(_SEED) +_GPU_IS_AVAILABLE = True if torch.cuda.is_available() or torch.xpu.is_available() else False +_DEVICE = auto_detect_device() # Helper function to run a function twice # and verify that the result is the same. @@ -614,12 +617,10 @@ def test_choose_qparams_tensor_asym_eps(self): eps = torch.finfo(torch.float32).eps self.assertEqual(scale, eps) - @unittest.skipIf( - not torch.cuda.is_available(), "skipping when cuda is not available" - ) + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available") def test_get_group_qparams_symmetric_memory(self): """Check the memory usage of the op""" - weight = torch.randn(1024, 1024).to(device="cuda") + weight = torch.randn(1024, 1024).to(device=_DEVICE) original_mem_use = torch.cuda.memory_allocated() n_bit = 4 groupsize = 128 diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index cc8f1179bf..b8b2530330 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -20,15 +20,19 @@ ) from torchao.sparsity import sparsify_ from torchao.sparsity.utils import create_semi_structured_tensor -from torchao.utils import is_sm_at_least_90 +from torchao.utils import ( + is_sm_at_least_90, + auto_detect_device, +) +_DEVICE = auto_detect_device() -@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +@unittest.skipIf( torch.cuda.is_available() and not is_sm_at_least_90(), "Need cuda arch greater than SM90") def test_sparse24_sm90_sparsify_identity( M=512, K=1024, fp8=torch.float8_e4m3fn ) -> None: torch.manual_seed(0) - A_sp_ref = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() + A_sp_ref = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).to(_DEVICE) # Test with act="identity" A_packed_ref, A_mdata_ref = to_sparse_semi_structured_cutlass_sm9x_f8( @@ -50,13 +54,13 @@ def test_sparse24_sm90_sparsify_identity( assert torch.allclose(A_packed.float().sum(), A_packed_ref.float().sum()) -@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +@unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need cuda arch greater than SM90") def test_sparse24_sm90_sparsify_identity_scaled( M=512, K=1024, fp8=torch.float8_e4m3fn ) -> None: torch.manual_seed(0) - A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() - A_scale = torch.randn([M, 1], device="cuda", dtype=torch.float32).abs() + 0.1 + A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).to(_DEVICE) + A_scale = torch.randn([M, 1], device=_DEVICE, dtype=torch.float32).abs() + 0.1 A_sp_ref = (A_dense / A_scale).bfloat16() A_packed_ref, A_mdata_ref = 
to_sparse_semi_structured_cutlass_sm9x_f8( @@ -77,10 +81,10 @@ def test_sparse24_sm90_sparsify_identity_scaled( ) -@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +@unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need cuda arch greater than SM90") def test_sparse24_sm90_sparsify_srelu(M=512, K=1024, fp8=torch.float8_e4m3fn) -> None: torch.manual_seed(0) - A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() + A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).to(_DEVICE) A_sp_ref = (A_dense.float().relu() ** 2).bfloat16() # Test with act="srelu" @@ -102,14 +106,14 @@ def test_sparse24_sm90_sparsify_srelu(M=512, K=1024, fp8=torch.float8_e4m3fn) -> assert (A_packed != A_packed_ref).float().mean().item() < 0.1 -@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +@unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need cuda arch greater than SM90") def test_srelu_fp8_semi_sparse_activation_linear(M=512, K=2048, N=1024): with torch.no_grad(): torch.manual_seed(0) - input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() + input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).to(_DEVICE) # we have to wrap in a sequential block for quantize_ to work properly reference_linear = torch.nn.Sequential( - torch.nn.Linear(K, N, bias=False).cuda().to(torch.bfloat16) + torch.nn.Linear(K, N, bias=False).to(_DEVICE).to(torch.bfloat16) ) reference_linear_copy = copy.deepcopy(reference_linear) @@ -144,13 +148,13 @@ def srelu_linear(x): torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) -@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +@unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need cuda arch greater than SM90") def test_sparse24_fp8_sm90_cutlass_gemm_eye( M=512, K=256, dtype=torch.float8_e4m3fn ) -> None: torch.manual_seed(0) - A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() + A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).to(_DEVICE) A_aqt = _float8_cutlass_quant(A_dense, dtype) A = A_aqt.tensor_impl.float8_data @@ -179,7 +183,7 @@ def test_sparse24_fp8_sm90_cutlass_gemm_eye( ) -@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +@unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need cuda arch greater than SM90") def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor( M=512, N=1024, K=256, dtype=torch.float8_e4m3fn ) -> None: @@ -190,10 +194,10 @@ def _to_fp8_rowwise(x: torch.Tensor, dtype): return x, x_scale torch.manual_seed(0) - A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() + A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).to(_DEVICE) A, a_scale = _to_fp8_rowwise(A_dense, dtype) - B_dense = torch.randn([N, K], device="cuda", dtype=torch.bfloat16) + B_dense = torch.randn([N, K], device=_DEVICE, dtype=torch.bfloat16) B, b_scale = _to_fp8_rowwise(B_dense, dtype) B = B.T diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 804a585dd8..a768ebc873 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -15,8 +15,13 @@ swap_linear_with_semi_sparse_linear, swap_semi_sparse_linear_with_linear, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_fbcode +from torchao.utils import ( + 
TORCH_VERSION_AT_LEAST_2_4, + is_fbcode, + auto_detect_device, +) +_DEVICE = auto_detect_device() class ToyModel(nn.Module): def __init__(self): @@ -33,16 +38,15 @@ def forward(self, x): class TestRuntimeSemiStructuredSparsity(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @unittest.skip("Temporarily skipping to unpin nightlies") def test_runtime_weight_sparsification(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT - input = torch.rand((128, 128)).half().cuda() - grad = torch.rand((128, 128)).half().cuda() - model = ToyModel().half().cuda() + input = torch.rand((128, 128)).half().to(_DEVICE) + grad = torch.rand((128, 128)).half().to(_DEVICE) + model = ToyModel().half().to(_DEVICE) model_c = copy.deepcopy(model) for name, mod in model.named_modules(): @@ -82,16 +86,15 @@ def test_runtime_weight_sparsification(self): assert not isinstance(mod, SemiSparseLinear) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @unittest.skip("Temporarily skipping to unpin nightlies") def test_runtime_weight_sparsification_compile(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT - input = torch.rand((128, 128)).half().cuda() - grad = torch.rand((128, 128)).half().cuda() - model = ToyModel().half().cuda() + input = torch.rand((128, 128)).half().to(_DEVICE) + grad = torch.rand((128, 128)).half().to(_DEVICE) + model = ToyModel().half().to(_DEVICE) model_c = copy.deepcopy(model) for name, mod in model.named_modules(): diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index 783de6c6ae..7e0b65c9f6 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -20,15 +20,19 @@ from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24 from torchao.sparsity.sparse_api import apply_fake_sparsity from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + auto_detect_device, +) +_DEVICE = auto_detect_device() class SparseMarlin24(TestCase): def setUp(self): super().setUp() torch.manual_seed(0) - self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda") + self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device=_DEVICE) self.model = ( nn.Sequential( nn.Linear(4096, 21504), @@ -41,7 +45,6 @@ def setUp(self): .cuda() ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @skip_if_rocm("ROCm enablement in progress") def test_quant_sparse_marlin_layout_eager(self): apply_fake_sparsity(self.model) @@ -59,7 +62,6 @@ def test_quant_sparse_marlin_layout_eager(self): ) @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @skip_if_rocm("ROCm enablement in progress") def test_quant_sparse_marlin_layout_compile(self): apply_fake_sparsity(self.model) @@ -79,7 +81,6 @@ def test_quant_sparse_marlin_layout_compile(self): "Results are not close" ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def 
test_pack_unpack_equivalence(self): num_bits = 4 group_size = 128 @@ -93,7 +94,7 @@ def test_pack_unpack_equivalence(self): mapping_type = MappingType.SYMMETRIC scale_dtype = None - w = torch.rand(shape, dtype=torch.float16, device="cuda") + w = torch.rand(shape, dtype=torch.float16, device=_DEVICE) # Inject 2:4 sparsity mask w_24, _ = inject_24(w, *w.shape) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 5e3086c411..11221497b9 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -23,8 +23,11 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + auto_detect_device ) +_DEVICE = auto_detect_device() + logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) @@ -32,17 +35,16 @@ class TestSemiStructuredSparse(common_utils.TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skip("Temporarily skipping to unpin nightlies") def test_sparse(self): - input = torch.rand((128, 128)).half().cuda() + input = torch.rand((128, 128)).half().to(_DEVICE) model = ( nn.Sequential( nn.Linear(128, 256), nn.Linear(256, 128), ) .half() - .cuda() + .to(_DEVICE) .eval() ) @@ -60,7 +62,6 @@ def test_sparse(self): class TestQuantSemiSparse(common_utils.TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [False]) @unittest.skip("Temporarily skip to unbreak CI") def test_quant_semi_sparse(self, compile): @@ -72,14 +73,14 @@ def test_quant_semi_sparse(self, compile): torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False - input = torch.rand((128, 128)).half().cuda() + input = torch.rand((128, 128)).half().to(_DEVICE) model = ( nn.Sequential( nn.Linear(128, 256), nn.Linear(256, 128), ) .half() - .cuda() + .to(_DEVICE) .eval() ) apply_fake_sparsity(model) @@ -98,20 +99,19 @@ def test_quant_semi_sparse(self, compile): torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) def test_sparse_marlin(self, compile): if not torch.backends.cusparselt.is_available(): self.skipTest("Need cuSPARSELt") - input = torch.rand((256, 256)).half().cuda() + input = torch.rand((256, 256)).half().to(_DEVICE) model = ( nn.Sequential( nn.Linear(256, 1024), nn.Linear(1024, 256), ) .half() - .cuda() + .to(_DEVICE) .eval() ) @@ -136,18 +136,17 @@ class TestBlockSparseWeight(common_utils.TestCase): not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support", ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize("input_shape", [1, 1024]) def test_sparse(self, compile, input_shape): - input = torch.rand((input_shape, 1024)).half().cuda() + input = torch.rand((input_shape, 1024)).half().to(_DEVICE) model = ( nn.Sequential( nn.Linear(1024, 2048), nn.Linear(2048, 1024), ) .half() - .cuda() + .to(_DEVICE) .eval() ) @@ -171,17 +170,16 @@ def test_sparse(self, compile, input_shape): class TestQuantBlockSparseWeight(common_utils.TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "pytorch 2.6+ 
feature") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) def test_sparse(self, compile): - input = torch.rand((256, 128)).to(torch.bfloat16).cuda() + input = torch.rand((256, 128)).to(torch.bfloat16).to(_DEVICE) model = ( nn.Sequential( nn.Linear(128, 256), nn.Linear(256, 128), ) .to(torch.bfloat16) - .cuda() + .to(_DEVICE) .eval() ) from torchao.sparsity.utils import create_block_sparse_tensor @@ -189,7 +187,7 @@ def test_sparse(self, compile): M, N = model[0].weight.shape model[0].weight.data = ( create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16) - * torch.rand(M, N, dtype=torch.bfloat16).cuda() + * torch.rand(M, N, dtype=torch.bfloat16).to(_DEVICE) ) M, N = model[1].weight.shape model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16) diff --git a/test/sparsity/test_supermask.py b/test/sparsity/test_supermask.py index 1ef40d12d7..bb00858890 100644 --- a/test/sparsity/test_supermask.py +++ b/test/sparsity/test_supermask.py @@ -11,13 +11,16 @@ from torch import nn from torch.testing._internal import common_utils +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() + logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) class TestSupermask(common_utils.TestCase): - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @common_utils.parametrize("sparsity_level", [0.25, 0.5]) @common_utils.parametrize("blocksize", [2, 4, 8]) def test_supermask(self, sparsity_level, blocksize): @@ -26,7 +29,7 @@ def test_supermask(self, sparsity_level, blocksize): nn.Linear(16, 16, bias=False), ) .half() - .cuda() + .to(_DEVICE) .eval() ) @@ -44,7 +47,6 @@ def test_supermask(self, sparsity_level, blocksize): expected = round((M // blocksize) * (N // blocksize) * (1 - sparsity_level)) assert nnz == expected, f"Expected {expected} nonzeros, got {nnz}" - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def test_from_linear(self): from torchao.sparsity import SupermaskLinear diff --git a/test/test_ao_models.py b/test/test_ao_models.py index 79e4cc3ef5..9b71019f63 100644 --- a/test/test_ao_models.py +++ b/test/test_ao_models.py @@ -7,8 +7,10 @@ import torch from torchao._models.llama.model import Transformer +from torchao.utils import get_available_devices -_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + +_DEVICES = get_available_devices() def init_model(name="stories15M", device="cpu", precision=torch.bfloat16): @@ -17,7 +19,7 @@ def init_model(name="stories15M", device="cpu", precision=torch.bfloat16): return model.eval() -@pytest.mark.parametrize("device", _AVAILABLE_DEVICES) +@pytest.mark.parametrize("device", _DEVICES) @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("is_training", [True, False]) def test_ao_llama_model_inference_mode(device, batch_size, is_training): diff --git a/test/test_model_architecture.py b/test/test_model_architecture.py index 973939a56a..dc0b6434e5 100644 --- a/test/test_model_architecture.py +++ b/test/test_model_architecture.py @@ -17,8 +17,6 @@ class TestModels(unittest.TestCase): @parameterized.expand([(device,) for device in get_available_devices()]) def test_toy_linear_model(self, device): # Skip if device is not available - if device == "cuda" and not torch.cuda.is_available(): - self.skipTest("CUDA not available") model, input_data = create_model_and_input_data( "linear", 10, 64, 
32, device=device @@ -29,8 +27,6 @@ def test_toy_linear_model(self, device): @parameterized.expand([(device,) for device in get_available_devices()]) def test_ln_linear_activation_model(self, device): # Skip if device is not available - if device == "cuda" and not torch.cuda.is_available(): - self.skipTest("CUDA not available") model, input_data = create_model_and_input_data( "ln_linear_sigmoid", 10, 64, 32, device=device @@ -41,8 +37,6 @@ def test_ln_linear_activation_model(self, device): @parameterized.expand([(device,) for device in get_available_devices()]) def test_transformer_block(self, device): # Skip if device is not available - if device == "cuda" and not torch.cuda.is_available(): - self.skipTest("CUDA not available") model, input_data = create_model_and_input_data( "transformer_block", 10, 64, 32, device=device diff --git a/test/test_ops.py b/test/test_ops.py index faec689a69..2a66d20216 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -33,8 +33,13 @@ compute_max_diff, ) +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() + IS_CUDA = torch.cuda.is_available() and torch.version.cuda IS_ROCM = torch.cuda.is_available() and torch.version.hip +IS_XPU = torch.xpu.is_available() and torch.version.xpu try: import torchao.ops @@ -60,7 +65,6 @@ def _create_floatx_inputs( fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 return floatx_weight.to(device), scale.to(device), fp16_act.to(device) - @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @parametrize("dtype", [torch.half, torch.bfloat16]) def test_quant_llm_linear(self, ebits, mbits, dtype): @@ -69,7 +73,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype): IC = 256 splitK = 1 floatx_weight, scale, fp16_act = self._create_floatx_inputs( - ebits, mbits, BS, OC, IC, "cuda", dtype + ebits, mbits, BS, OC, IC, _DEVICE, dtype ) # smoke test @@ -90,7 +94,6 @@ def test_quant_llm_linear(self, ebits, mbits, dtype): test_utils=test_utils, ) - @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @parametrize("dtype", [torch.half, torch.bfloat16]) @@ -99,7 +102,7 @@ def test_quant_llm_linear_correctness( ): # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py floatx_weight, scale, fp16_act = self._create_floatx_inputs( - ebits, mbits, BS, OC, IC, "cuda", dtype + ebits, mbits, BS, OC, IC, _DEVICE, dtype ) results_floatx = torchao.ops.quant_llm_linear( @@ -287,7 +290,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): N, K = shape assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0 - t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") + t = torch.randint(0, 16, dtype=torch.int, size=shape, device=_DEVICE) if TORCH_VERSION_AT_LEAST_2_5: t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) @@ -312,7 +315,7 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): if TORCH_VERSION_AT_LEAST_2_5: test_utils.append("test_aot_dispatch_dynamic") - t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") + t = torch.randint(0, 16, dtype=torch.int, size=shape, device=_DEVICE) if TORCH_VERSION_AT_LEAST_2_5: t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) packed_w = 
torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) @@ -355,7 +358,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant( n, k = shape dtype = torch.bfloat16 - device = "cuda" + device = _DEVICE t = torch.randn(n, k, dtype=dtype, device=device) scales, zeros = get_groupwise_affine_qparams( @@ -422,7 +425,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( ): n, k = shape dtype = torch.bfloat16 - device = "cuda" + device = _DEVICE # Quantize and pack t = torch.randn(n, k, dtype=dtype, device=device) @@ -485,7 +488,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( ) def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size): n, k = shape - device = "cuda" + device = _DEVICE q = torch.randint(0, 16, shape, dtype=torch.int, device=device) if TORCH_VERSION_AT_LEAST_2_5: @@ -603,9 +606,9 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto size_n = n_chunk * n_factor a_input = torch.randn( - (batch_size, size_m, size_k), dtype=torch.float16, device="cuda" + (batch_size, size_m, size_k), dtype=torch.float16, device=_DEVICE ) - b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda") + b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device=_DEVICE) # Inject 2:4 sparsity w_24, _ = inject_24(b_weight, size_k, size_n) diff --git a/test/test_ops_rowwise_scaled_linear_cutlass.py b/test/test_ops_rowwise_scaled_linear_cutlass.py index 72bb201b3f..ef117d021b 100644 --- a/test/test_ops_rowwise_scaled_linear_cutlass.py +++ b/test/test_ops_rowwise_scaled_linear_cutlass.py @@ -17,6 +17,9 @@ _int8_symm_cutlass_quant, ) from torchao.testing.utils import get_compute_capability +from torchao.utils import auto_detect_device + +_DEVICE = auto_detect_device() DTYPES = [torch.float16, torch.bfloat16] BATCH_SIZE = [1, 4, 8, 16, 32, 64] @@ -42,9 +45,9 @@ def run_test_for_op(op, dtype, batch_size, size_mnk, use_bias): size_m, size_n, size_k = size_mnk - X = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") - W = torch.rand((size_n, size_k), dtype=dtype, device="cuda") - bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None + X = torch.randn((batch_size, size_m, size_k), dtype=dtype, device=_DEVICE) + W = torch.rand((size_n, size_k), dtype=dtype, device=_DEVICE) + bias = torch.rand((size_n,), dtype=dtype, device=_DEVICE) if use_bias else None Xq_bits = 4 if op == torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4 else 8 @@ -87,8 +90,6 @@ def run_test_for_op(op, dtype, batch_size, size_mnk, use_bias): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(get_compute_capability() != 8.0, reason="Only supported on A100") @pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS) def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias): run_test_for_op( @@ -100,8 +101,6 @@ def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bia ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(get_compute_capability() != 8.0, reason="Only supported on A100") @pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS) def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias): run_test_for_op( diff --git a/test/test_ops_rowwise_scaled_linear_sparse_cutlass.py 
b/test/test_ops_rowwise_scaled_linear_sparse_cutlass.py
index 938c3337b9..a7b078c83e 100644
--- a/test/test_ops_rowwise_scaled_linear_sparse_cutlass.py
+++ b/test/test_ops_rowwise_scaled_linear_sparse_cutlass.py
@@ -19,6 +19,9 @@ )
 from torchao.sparsity.utils import create_semi_structured_tensor
 from torchao.testing.utils import skip_if_rocm
+from torchao.utils import auto_detect_device
+
+_DEVICE = auto_detect_device()

 DTYPES = [torch.float16, torch.bfloat16]
 XQ_WQ_DTYPES = [
@@ -57,7 +60,7 @@ def run_test_for_op(
     size_mnk,
     use_bias,
 ):
-    device = "cuda"
+    device = _DEVICE

     size_m, size_n, size_k = size_mnk

@@ -106,8 +109,6 @@ def run_test_for_op(

 @skip_if_rocm("does not yet work on ROCm")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices")
 @pytest.mark.parametrize(
     "dtype, Xq_Wq_dtypes, batch_size, size_mnk, use_bias",
     TEST_PARAMS,
diff --git a/torchao/utils.py b/torchao/utils.py
index fb82b9f005..ccdd201644 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -145,6 +145,16 @@ def get_available_devices():
         devices.append("mps")
     return devices

+def auto_detect_device():
+    # ROCm builds expose their GPU through the "cuda" device type, so the check
+    # below also covers HIP; "rocm" is not a valid torch device string.
+    if torch.cuda.is_available():
+        return "cuda"
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        return "xpu"
+    else:
+        return "cpu"
+
 def get_compute_capability():
     if torch.cuda.is_available():