Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ quote-style = "single"
convention = "google"

[tool.pytest.ini_options]
testpaths = ["test"]
addopts = [
"--capture=no",
"--color=yes",
Expand All @@ -207,6 +208,8 @@ filterwarnings = [
"error::DeprecationWarning",
# TODO(rishipuri98): Remove usage of `torch_geometric.distributed` from `torch_geometric.llm`
"ignore:.*torch_geometric.distributed.*:DeprecationWarning",
# Ignore due to pytorch-lightning using a deprecated pkg_resources API:
"ignore:.*Deprecated call to `pkg_resources.declare_namespace.*:DeprecationWarning",
]
markers = [
"rag: mark test as RAG test",
Expand Down
15 changes: 15 additions & 0 deletions test/explain/algorithm/test_attention_explainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os

import pytest
import torch

import torch_geometric.typing
from torch_geometric.explain import (
AttentionExplainer,
Explainer,
Expand Down Expand Up @@ -180,6 +183,10 @@ def forward(self, x_dict, edge_index_dict):
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])


@pytest.mark.skipif(
os.name == 'nt' and torch_geometric.typing.WITH_PT24,
reason="Unknown heap corruption issue on Windows with PyTorch 2.4",
)
@pytest.mark.parametrize('index', [None, 2, torch.arange(3)])
def test_attention_explainer(index, check_explanation):
explainer = Explainer(
Expand Down Expand Up @@ -216,6 +223,10 @@ def test_attention_explainer_supports(explanation_type, node_mask_type):
)


@pytest.mark.skipif(
os.name == 'nt' and torch_geometric.typing.WITH_PT24,
reason="Unknown heap corruption issue on Windows with PyTorch 2.4",
)
def test_attention_explainer_attentive_fp(check_explanation):
model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2)

Expand All @@ -235,6 +246,10 @@ def test_attention_explainer_attentive_fp(check_explanation):
check_explanation(explanation, None, explainer.edge_mask_type)


@pytest.mark.skipif(
os.name == 'nt' and torch_geometric.typing.WITH_PT24,
reason="Unknown heap corruption issue on Windows with PyTorch 2.4",
)
@pytest.mark.parametrize('index', [None, 2, torch.arange(3)])
def test_attention_explainer_hetero(index, hetero_data,
check_explanation_hetero):
Expand Down
4 changes: 3 additions & 1 deletion test/nn/conv/test_point_transformer_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import torch_geometric.typing
from torch_geometric.nn import PointTransformerConv
from torch_geometric.testing import is_full_test
from torch_geometric.testing import is_full_test, withPackage
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


# Skip on PyTorch 1.13 due to numerical instability:
@withPackage('torch>=2.0.0')
def test_point_transformer_conv():
x1 = torch.rand(4, 16)
x2 = torch.randn(2, 8)
Expand Down
2 changes: 1 addition & 1 deletion test/nn/conv/test_sage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_multi_aggr_sage_conv(aggr_kwargs):

@withDevice
@onlyLinux
@withPackage('torch>=2.1.0')
@withPackage('torch>=2.1.0', 'torch!=2.3.0', 'torch!=2.3.1')
def test_compile_multi_aggr_sage_conv(device):
import torch._dynamo as dynamo

Expand Down
2 changes: 1 addition & 1 deletion test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_basic_gnn_inference(get_dataset, jk):
@withDevice
@onlyLinux
@onlyFullTest
@withPackage('torch>=2.0.0')
@withPackage('torch>=2.0.0', 'torch!=2.3.0', 'torch!=2.3.1')
def test_compile_basic(device):
x = torch.randn(3, 8, device=device)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)
Expand Down
2 changes: 1 addition & 1 deletion test/nn/models/test_lpformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch_geometric.utils import to_undirected


@withPackage('numba') # For ppr calculation
@withPackage('numba', 'torch>=2.0.0') # For ppr calculation
def test_lpformer():
model = LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)
assert str(
Expand Down
2 changes: 1 addition & 1 deletion test/nn/test_compile_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def fused_gather_scatter(x, edge_index, reduce=('sum', 'mean', 'max')):
@withDevice
@onlyLinux
@onlyFullTest
@withPackage('torch>=2.0.0')
@withPackage('torch>=2.0.0', 'torch!=2.3.0', 'torch!=2.3.1')
def test_torch_compile(device):
x = torch.randn(10, 16, device=device)
edge_index = torch.randint(0, x.size(0), (2, 40), device=device)
Expand Down
4 changes: 2 additions & 2 deletions test/nn/test_compile_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
@withDevice
@onlyLinux
@onlyFullTest
@withPackage('torch>=2.1.0')
@withPackage('torch>=2.1.0', 'torch!=2.3.0', 'torch!=2.3.1')
@pytest.mark.parametrize('Conv', [GCNConv, SAGEConv])
def test_compile_conv(device, Conv):
import torch._dynamo as dynamo
Expand All @@ -52,7 +52,7 @@ def test_compile_conv(device, Conv):
@withDevice
@onlyLinux
@onlyFullTest
@withPackage('torch==2.3')
@withPackage('torch>=2.4.0')
@pytest.mark.parametrize('Conv', [GCNConv, SAGEConv])
def test_compile_conv_edge_index(device, Conv):
import torch._dynamo as dynamo
Expand Down
2 changes: 1 addition & 1 deletion test/nn/test_compile_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
@withDevice
@onlyLinux
@onlyFullTest
@withPackage('torch>2.0.0')
@withPackage('torch>2.0.0', 'torch!=2.3.0', 'torch!=2.3.1')
def test_dynamic_torch_compile(device):
conv = MySAGEConv(64, 64).to(device)
conv = torch.compile(conv, dynamic=True)
Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/data/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ def extract_tar(
"""
maybe_log(path, log)
with tarfile.open(path, mode) as f:
f.extractall(folder, filter='data')
if sys.version_info >= (3, 12):
f.extractall(folder, filter='data')
else:
f.extractall(folder)


def extract_zip(path: str, folder: str, log: bool = True) -> None:
Expand Down
Loading