diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 718d8668..95285a6d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,7 @@ repos: hooks: - id: yamlfmt - repo: https://github.com/timothycrosley/isort - rev: 5.7.0 + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/psf/black @@ -58,7 +58,7 @@ repos: hooks: - id: black language_version: python3.8 - - repo: https://gitlab.com/pycqa/flake8 + - repo: https://github.com/pycqa/flake8 rev: 3.8.4 hooks: - id: flake8 diff --git a/programl/requirements.txt b/programl/requirements.txt index db4d6937..28f0f482 100644 --- a/programl/requirements.txt +++ b/programl/requirements.txt @@ -1,5 +1,5 @@ absl-py>=0.11.0 -dgl>=0.6.1,<=0.9.1 +dgl==1.1.1 grpcio>=1.33.2 networkx>=2.4 numpy>=1.19.3 diff --git a/programl/transform_ops.py b/programl/transform_ops.py index 6c76734b..420e1171 100644 --- a/programl/transform_ops.py +++ b/programl/transform_ops.py @@ -22,7 +22,7 @@ import dgl import networkx as nx -from dgl.heterograph import DGLHeteroGraph +from dgl import DGLGraph from networkx.readwrite import json_graph as nx_json from programl.exceptions import GraphTransformError @@ -168,7 +168,7 @@ def to_dgl( timeout: int = 300, executor: Optional[ExecutorLike] = None, chunksize: Optional[int] = None, -) -> Union[DGLHeteroGraph, Iterable[DGLHeteroGraph]]: +) -> Union[DGLGraph, Iterable[DGLGraph]]: """Convert one or more Program Graphs to `DGLGraphs `_. @@ -201,7 +201,7 @@ def to_dgl( """ def _run_one(nx_graph): - return dgl.DGLGraph(nx_graph) + return dgl.from_networkx(nx_graph) if isinstance(graphs, ProgramGraph): return _run_one(to_networkx(graphs)) diff --git a/tests/BUILD b/tests/BUILD index 211ec444..ba238258 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -134,3 +134,13 @@ py_test( "//tests/plugins", ], ) + +py_test( + name = "to_dgl_test", + srcs = ["to_dgl_test.py"], + deps = [ + "//programl", + "//tests:test_main", + "//tests/plugins", + ], +) diff --git a/tests/to_dgl_test.py b/tests/to_dgl_test.py new file mode 100644 index 00000000..371cca55 --- /dev/null +++ b/tests/to_dgl_test.py @@ -0,0 +1,32 @@ +import pytest +from dgl import DGLGraph + +import programl as pg +from tests.test_main import main + +pytest_plugins = ["tests.plugins.llvm_program_graph"] + + +@pytest.fixture(scope="session") +def graph() -> pg.ProgramGraph: + return pg.from_cpp("int A() { return 0; }") + + +def test_to_dgl_simple_graph(graph: pg.ProgramGraph): + graphs = list(pg.to_dgl([graph])) + assert len(graphs) == 1 + assert isinstance(graphs[0], DGLGraph) + + +def test_to_dgl_simple_graph_single_input(graph: pg.ProgramGraph): + dgl_graph = pg.to_dgl(graph) + assert isinstance(dgl_graph, DGLGraph) + + +def test_to_dgl_two_inputs(graph: pg.ProgramGraph): + graphs = list(pg.to_dgl([graph, graph])) + assert len(graphs) == 2 + + +if __name__ == "__main__": + main() diff --git a/tools/perf_monitor/requirements.txt b/tools/perf_monitor/requirements.txt index c271188b..5fdaf887 100644 --- a/tools/perf_monitor/requirements.txt +++ b/tools/perf_monitor/requirements.txt @@ -1,2 +1,2 @@ GPUtil==1.4.0 -psutil==5.4.5 +psutil==5.8.0