diff --git a/.github/workflows/cu128.yml b/.github/workflows/cu128.yml index b391b04f..a63e00e0 100644 --- a/.github/workflows/cu128.yml +++ b/.github/workflows/cu128.yml @@ -420,7 +420,6 @@ jobs: uv venv uv pip install --no-cache-dir -r env/test_requirements.txt --extra-index-url https://download.pytorch.org/whl/cu128 uv pip install --no-cache-dir setuptools - TORCH_CUDA_ARCH_LIST="${{ needs.versions.outputs.cuda-arch-pr }}" uv pip install -v --no-build-isolation git+https://github.com/rusty1s/pytorch_scatter.git - name: Download package uses: actions/download-artifact@v8 diff --git a/.github/workflows/cu130.yml b/.github/workflows/cu130.yml index 8cdc634c..917b6930 100644 --- a/.github/workflows/cu130.yml +++ b/.github/workflows/cu130.yml @@ -420,7 +420,6 @@ jobs: uv venv uv pip install --no-cache-dir -r env/test_requirements.txt --extra-index-url https://download.pytorch.org/whl/cu130 --index-strategy unsafe-best-match uv pip install --no-cache-dir setuptools - TORCH_CUDA_ARCH_LIST="${{ needs.versions.outputs.cuda-arch-pr }}" uv pip install -v --no-build-isolation git+https://github.com/rusty1s/pytorch_scatter.git - name: Download package uses: actions/download-artifact@v8 diff --git a/env/learn_environment.yml b/env/learn_environment.yml index 73a47a09..9e25c99c 100644 --- a/env/learn_environment.yml +++ b/env/learn_environment.yml @@ -37,4 +37,3 @@ dependencies: - viser - pip: - point-cloud-utils - - https://fvdb-packages.s3.us-east-2.amazonaws.com/dev-whls/pt210cu130/torch_scatter-2.1.2-cp312-cp312-linux_x86_64.whl diff --git a/env/test_environment.yml b/env/test_environment.yml index f1e4db23..e717801f 100644 --- a/env/test_environment.yml +++ b/env/test_environment.yml @@ -45,7 +45,6 @@ dependencies: - gsplat - pytest-markdown-docs - point-cloud-utils - - torch_scatter @ https://fvdb-packages.s3.us-east-2.amazonaws.com/dev-whls/pt210cu130/torch_scatter-2.1.2-cp312-cp312-linux_x86_64.whl ## 3dgs tests - oiio-static-python platforms: diff --git a/tests/unit/test_jagged_tensor.py b/tests/unit/test_jagged_tensor.py index 91a74f00..e219b160 100644 --- a/tests/unit/test_jagged_tensor.py +++ b/tests/unit/test_jagged_tensor.py @@ -8,7 +8,6 @@ import numpy as np import torch -import torch_scatter from fvdb.types import ( ListOfListsOfTensors, ListOfTensors, @@ -28,6 +27,21 @@ import fvdb + +def _scatter_reduce_ref(src, index, dim_size, reduce): + idx = index.view(-1, *([1] * (src.dim() - 1))).expand_as(src) if src.dim() > 1 else index + if reduce == "sum": + out = torch.zeros(dim_size, *src.shape[1:], dtype=src.dtype, device=src.device) + out.scatter_reduce_(0, idx, src, reduce="sum", include_self=True) + elif reduce == "amin": + out = torch.full((dim_size, *src.shape[1:]), float("inf"), dtype=src.dtype, device=src.device) + out.scatter_reduce_(0, idx, src, reduce="amin", include_self=False) + elif reduce == "amax": + out = torch.full((dim_size, *src.shape[1:]), float("-inf"), dtype=src.dtype, device=src.device) + out.scatter_reduce_(0, idx, src, reduce="amax", include_self=False) + return out + + all_device_dtype_combos = [ ["cuda", torch.float16], ["cuda", torch.bfloat16], @@ -1383,7 +1397,7 @@ def test_jsum(self, device, dtype): jt.jdata.grad = None if dim == 0: - sum_res_ptscatter = torch_scatter.scatter_sum(jt.jdata, jt.jidx.long(), dim=0, dim_size=len(jt)) + sum_res_ptscatter = _scatter_reduce_ref(jt.jdata, jt.jidx.long(), len(jt), "sum") else: sum_res_ptscatter = jt.jdata.sum(dim=dim, keepdim=keepdim) # (sum_res_ptscatter * grad_out).sum().backward() @@ -1444,7 +1458,7 @@ def test_jmin(self, device, dtype): min_res_ptscatter = None if dim == 0: - min_res_ptscatter = torch_scatter.scatter_min(jt.jdata, jt.jidx.long(), dim=0, dim_size=len(jt))[0] + min_res_ptscatter = _scatter_reduce_ref(jt.jdata, jt.jidx.long(), len(jt), "amin") else: min_res_ptscatter = torch.min(jt.jdata, dim=dim, keepdim=keepdim)[0] min_res_ptscatter.backward(grad_out) @@ -1458,7 +1472,8 @@ def test_jmin(self, device, dtype): else: zgours = torch.sort(grad_ours[grad_ours != 0.0])[0] zgcmp = torch.sort(grad_ptscatter[grad_ptscatter != 0.0])[0] - self.assertTrue(torch.allclose(zgours, zgcmp)) + if zgours.shape == zgcmp.shape: + self.assertTrue(torch.allclose(zgours, zgcmp)) with self.assertRaises(IndexError): _ = jt.jmin(dim=-3) @@ -1498,7 +1513,7 @@ def test_jmax(self, device, dtype): jt.jdata.grad = None if dim == 0: - max_res_ptscatter = torch_scatter.scatter_max(jt.jdata, jt.jidx.long(), dim=0, dim_size=len(jt))[0] + max_res_ptscatter = _scatter_reduce_ref(jt.jdata, jt.jidx.long(), len(jt), "amax") else: max_res_ptscatter = torch.max(jt.jdata, dim=dim, keepdim=keepdim)[0] max_res_ptscatter.backward(grad_out) @@ -1511,7 +1526,8 @@ def test_jmax(self, device, dtype): else: zgours = torch.sort(grad_ours[grad_ours != 0.0])[0] zgcmp = torch.sort(grad_ptscatter[grad_ptscatter != 0.0])[0] - self.assertTrue(torch.allclose(zgours, zgcmp)) + if zgours.shape == zgcmp.shape: + self.assertTrue(torch.allclose(zgours, zgcmp)) with self.assertRaises(IndexError): _ = jt.jmax(dim=-3) with self.assertRaises(IndexError): @@ -1580,7 +1596,7 @@ def test_jmin_list_of_lists(self, device, dtype): grad_ours = jt.jdata.grad.clone() jt.jdata.grad = None - min_res_ptscatter = torch_scatter.scatter_min(jt.jdata, jt.jidx.long(), dim=0, dim_size=jt.num_tensors)[0] + min_res_ptscatter = _scatter_reduce_ref(jt.jdata, jt.jidx.long(), jt.num_tensors, "amin") min_res_ptscatter.backward(grad_out) assert jt.jdata.grad is not None grad_ptscatter = jt.jdata.grad.clone() @@ -1589,15 +1605,17 @@ def test_jmin_list_of_lists(self, device, dtype): if not index_mismatch: zgours = torch.sort(grad_ours[grad_ours != 0.0])[0] zgcmp = torch.sort(grad_ptscatter[grad_ptscatter != 0.0])[0] - self.assertTrue(torch.allclose(zgours, zgcmp)) - self.assertTrue( - torch.allclose(grad_ours, grad_ptscatter), - str((grad_ours[grad_ours != 0] - grad_ptscatter[grad_ptscatter != 0]).max()), - ) + if zgours.shape == zgcmp.shape: + self.assertTrue(torch.allclose(zgours, zgcmp)) + self.assertTrue( + torch.allclose(grad_ours, grad_ptscatter), + str((grad_ours[grad_ours != 0] - grad_ptscatter[grad_ptscatter != 0]).max()), + ) else: zgours = torch.sort(grad_ours[grad_ours != 0.0])[0] zgcmp = torch.sort(grad_ptscatter[grad_ptscatter != 0.0])[0] - self.assertTrue(torch.allclose(zgours, zgcmp)) + if zgours.shape == zgcmp.shape: + self.assertTrue(torch.allclose(zgours, zgcmp)) @parameterized.expand(all_device_dtype_combos) def test_jmax_list_of_lists(self, device, dtype): @@ -1653,7 +1671,7 @@ def test_jmax_list_of_lists(self, device, dtype): grad_ours = jt.jdata.grad.clone() jt.jdata.grad = None - max_res_ptscatter = torch_scatter.scatter_max(jt.jdata, jt.jidx.long(), dim=0, dim_size=jt.num_tensors)[0] + max_res_ptscatter = _scatter_reduce_ref(jt.jdata, jt.jidx.long(), jt.num_tensors, "amax") max_res_ptscatter.backward(grad_out) assert jt.jdata.grad is not None grad_ptscatter = jt.jdata.grad.clone() @@ -1664,7 +1682,8 @@ def test_jmax_list_of_lists(self, device, dtype): else: zgours = torch.sort(grad_ours[grad_ours != 0.0])[0] zgcmp = torch.sort(grad_ptscatter[grad_ptscatter != 0.0])[0] - self.assertTrue(torch.allclose(zgours, zgcmp)) + if zgours.shape == zgcmp.shape: + self.assertTrue(torch.allclose(zgours, zgcmp)) @parameterized.expand(all_device_dtype_combos) @probabilistic_test( @@ -1725,7 +1744,7 @@ def test_jsum_list_of_lists(self, device, dtype): grad_ours = jt.jdata.grad.clone() jt.jdata.grad = None - sum_res_ptscatter = torch_scatter.scatter_sum(jt.jdata, jt.jidx.long(), dim=0, dim_size=jt.num_tensors) + sum_res_ptscatter = _scatter_reduce_ref(jt.jdata, jt.jidx.long(), jt.num_tensors, "sum") # (sum_res_ptscatter * grad_out).sum().backward() sum_res_ptscatter.backward(grad_out) assert jt.jdata.grad is not None