Skip to content

Commit b7e883c

Browse files
authored
Merge pull request #39 from PySATL/new_stats
feat: add new stats for graph stats - clique number and independence number
2 parents a57739e + 70584d1 commit b7e883c

File tree

8 files changed

+279 / -90 lines changed

8 files changed

+279 / -90 lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ jobs:
1414
runs-on: ubuntu-latest
1515
strategy:
1616
matrix:
17-
python-version: [ "3.10", "3.11", "3.12" ]
17+
python-version: [ "3.11", "3.12" ]
1818

1919
steps:
2020
- name: Checkout repository

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,13 @@ authors = [
77
]
88
license = {text = "MIT"}
99
readme = "README.md"
10-
requires-python = ">=3.9,<3.13"
10+
requires-python = ">=3.11,<3.13"
1111
dependencies = [
1212
"numpy>=1.25.1",
1313
"scipy>=1.11.2",
1414
"pandas>=2.2.1",
15-
"typing-extensions>=4.12.2"
15+
"typing-extensions>=4.12.2",
16+
"networkx (>=3.5,<4.0)"
1617
]
1718

1819
[project.urls]

pysatl_criterion/persistence/limit_distribution/sqlite/sqlite.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ def get_data(self, query: LimitDistributionQuery) -> LimitDistributionModel | No
120120
return None
121121

122122
columns = [col[0] for col in cursor.description]
123-
return self._row_to_model(dict(zip(columns, row)))
123+
return self._row_to_model(dict(zip(columns, row, strict=False)))
124124

125125
def delete_data(self, query: LimitDistributionQuery) -> None:
126126
"""Delete specific limit distribution data."""
@@ -164,4 +164,4 @@ def get_data_for_cv(self, query: CriticalValueQuery) -> LimitDistributionModel |
164164
return None
165165

166166
columns = [col[0] for col in cursor.description]
167-
return self._row_to_model(dict(zip(columns, row)))
167+
return self._row_to_model(dict(zip(columns, row, strict=False)))

pysatl_criterion/statistics/exponent.py

Lines changed: 42 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,12 @@
99
from pysatl_criterion.statistics.common import KSStatistic
1010
from pysatl_criterion.statistics.goodness_of_fit import AbstractGoodnessOfFitStatistic
1111
from pysatl_criterion.statistics.graph_goodness_of_fit import (
12+
AbstractGraphTestStatistic,
1213
GraphAverageDegreeTestStatistic,
14+
GraphCliqueNumberTestStatistic,
1315
GraphConnectedComponentsTestStatistic,
1416
GraphEdgesNumberTestStatistic,
17+
GraphIndependenceNumberTestStatistic,
1518
GraphMaxDegreeTestStatistic,
1619
)
1720

@@ -810,73 +813,78 @@ def execute_statistic(self, rvs, **kwargs):
810813
return hg
811814

812815

813-
class GraphEdgesNumberExponentialityGofStatistic(
814-
AbstractExponentialityGofStatistic, GraphEdgesNumberTestStatistic
816+
class AbstractGraphExponentialityGofStatistic(
817+
AbstractExponentialityGofStatistic, AbstractGraphTestStatistic
815818
):
816819
@staticmethod
817820
@override
818821
def code():
819-
super_class = AbstractExponentialityGofStatistic
820-
parent_code = super(super_class, super_class).code()
821-
return f"EdgesNumber_{parent_code}"
822+
parent_code = AbstractExponentialityGofStatistic.code()
823+
return f"GRAPH_{parent_code}"
822824

823825
@staticmethod
824826
@override
825827
def _compute_dist(rvs):
826-
super_class = GraphEdgesNumberTestStatistic
827-
parent_code = super(super_class, super_class)._compute_dist(rvs)
828-
return parent_code / np.mean(rvs)
828+
base_dist = AbstractGraphTestStatistic._compute_dist(rvs)
829+
mean = np.mean(rvs)
830+
return base_dist / mean if mean != 0 else base_dist
829831

830832

831-
class GraphMaxDegreeExponentialityGofStatistic(
832-
AbstractExponentialityGofStatistic, GraphMaxDegreeTestStatistic
833+
class GraphEdgesNumberExponentialityGofStatistic(
834+
AbstractGraphExponentialityGofStatistic, GraphEdgesNumberTestStatistic
833835
):
834836
@staticmethod
835837
@override
836838
def code():
837-
super_class = AbstractExponentialityGofStatistic
838-
parent_code = super(super_class, super_class).code()
839-
return f"MaxDegree_{parent_code}"
839+
parent_code = AbstractGraphExponentialityGofStatistic.code()
840+
return f"{GraphEdgesNumberExponentialityGofStatistic.get_stat_name()}_{parent_code}"
840841

842+
843+
class GraphMaxDegreeExponentialityGofStatistic(
844+
AbstractGraphExponentialityGofStatistic, GraphMaxDegreeTestStatistic
845+
):
841846
@staticmethod
842847
@override
843-
def _compute_dist(rvs):
844-
super_class = GraphMaxDegreeTestStatistic
845-
parent_code = super(super_class, super_class)._compute_dist(rvs)
846-
return parent_code / np.mean(rvs)
848+
def code():
849+
parent_code = AbstractGraphExponentialityGofStatistic.code()
850+
return f"{GraphMaxDegreeExponentialityGofStatistic.get_stat_name()}_{parent_code}"
847851

848852

849853
class GraphAverageDegreeExponentialityGofStatistic(
850-
AbstractExponentialityGofStatistic, GraphAverageDegreeTestStatistic
854+
AbstractGraphExponentialityGofStatistic, GraphAverageDegreeTestStatistic
851855
):
852856
@staticmethod
853857
@override
854858
def code():
855-
super_class = AbstractExponentialityGofStatistic
856-
parent_code = super(super_class, super_class).code()
857-
return f"AverageDegree_{parent_code}"
859+
parent_code = AbstractGraphExponentialityGofStatistic.code()
860+
return f"{GraphAverageDegreeExponentialityGofStatistic.get_stat_name()}_{parent_code}"
861+
858862

863+
class GraphConnectedComponentsExponentialityGofStatistic(
864+
AbstractGraphExponentialityGofStatistic, GraphConnectedComponentsTestStatistic
865+
):
859866
@staticmethod
860867
@override
861-
def _compute_dist(rvs):
862-
super_class = GraphAverageDegreeTestStatistic
863-
parent_dist = super(super_class, super_class)._compute_dist(rvs)
864-
return parent_dist / np.mean(rvs)
868+
def code():
869+
parent_code = AbstractGraphExponentialityGofStatistic.code()
870+
return f"{GraphConnectedComponentsExponentialityGofStatistic.get_stat_name()}_{parent_code}"
865871

866872

867-
class GraphConnectedComponentsExponentialityGofStatistic(
868-
AbstractExponentialityGofStatistic, GraphConnectedComponentsTestStatistic
873+
class GraphCliqueNumberExponentialityGofStatistic(
874+
AbstractGraphExponentialityGofStatistic, GraphCliqueNumberTestStatistic
869875
):
870876
@staticmethod
871877
@override
872878
def code():
873-
super_class = AbstractExponentialityGofStatistic
874-
parent_code = super(super_class, super_class).code()
875-
return f"ConnectedComponents_{parent_code}"
879+
parent_code = AbstractGraphExponentialityGofStatistic.code()
880+
return f"{GraphCliqueNumberExponentialityGofStatistic.get_stat_name()}_{parent_code}"
876881

882+
883+
class GraphIndependenceNumberExponentialityGofStatistic(
884+
AbstractGraphExponentialityGofStatistic, GraphIndependenceNumberTestStatistic
885+
):
877886
@staticmethod
878887
@override
879-
def _compute_dist(rvs):
880-
super_class = GraphConnectedComponentsTestStatistic
881-
parent_dist = super(super_class, super_class)._compute_dist(rvs)
882-
return parent_dist / np.mean(rvs)
888+
def code():
889+
parent_code = AbstractGraphExponentialityGofStatistic.code()
890+
return f"{GraphIndependenceNumberExponentialityGofStatistic.get_stat_name()}_{parent_code}"

pysatl_criterion/statistics/graph_goodness_of_fit.py

Lines changed: 88 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
from abc import ABC
2-
from typing import Union
32

43
import numpy as np
54
from numpy import float64
@@ -10,13 +9,21 @@
109

1110
class AbstractGraphTestStatistic(AbstractGoodnessOfFitStatistic, ABC):
1211
@override
13-
def execute_statistic(self, rvs, **kwargs) -> Union[float, float64]:
12+
def execute_statistic(self, rvs, **kwargs) -> float | float64:
1413
dist = self._compute_dist(rvs)
1514

1615
adjacency_list = self._make_adjacency_list(rvs, dist)
1716
statistic = self.get_graph_stat(adjacency_list)
1817
return statistic
1918

19+
@staticmethod
20+
def get_stat_name() -> str:
21+
raise NotImplementedError("Method is not implemented")
22+
23+
@staticmethod
24+
def get_graph_stat(graph: list[list[int]]) -> float:
25+
raise NotImplementedError("Method is not implemented")
26+
2027
@staticmethod
2128
def _make_adjacency_list(rvs, dist: float) -> list[list[int]]:
2229
adjacency_list: list[list[int]] = []
@@ -31,40 +38,51 @@ def _make_adjacency_list(rvs, dist: float) -> list[list[int]]:
3138
return adjacency_list
3239

3340
@staticmethod
34-
def _compute_dist(rvs): # TODO (normalize for different distributions)
41+
def _compute_dist(rvs: list[float]) -> float: # TODO (normalize for different distributions)
3542
return (max(rvs) - min(rvs)) / 10
3643

37-
@staticmethod
38-
def get_graph_stat(graph: list[list[int]]):
39-
raise NotImplementedError("Method is not implemented")
4044

41-
42-
class GraphEdgesNumberTestStatistic(AbstractGraphTestStatistic, ABC):
45+
class GraphEdgesNumberTestStatistic(AbstractGraphTestStatistic):
4346
@staticmethod
4447
@override
45-
def get_graph_stat(graph):
48+
def get_graph_stat(graph: list[list[int]]) -> float:
4649
return sum(map(len, graph)) // 2
4750

51+
@staticmethod
52+
@override
53+
def get_stat_name() -> str:
54+
return "EDGESNUMBER"
55+
4856

49-
class GraphMaxDegreeTestStatistic(AbstractGraphTestStatistic, ABC):
57+
class GraphMaxDegreeTestStatistic(AbstractGraphTestStatistic):
5058
@staticmethod
5159
@override
52-
def get_graph_stat(graph):
60+
def get_graph_stat(graph: list[list[int]]) -> float:
5361
return max(map(len, graph))
5462

63+
@staticmethod
64+
@override
65+
def get_stat_name() -> str:
66+
return "MAXDEGREE"
67+
5568

56-
class GraphAverageDegreeTestStatistic(AbstractGraphTestStatistic, ABC):
69+
class GraphAverageDegreeTestStatistic(AbstractGraphTestStatistic):
5770
@staticmethod
5871
@override
59-
def get_graph_stat(graph):
72+
def get_graph_stat(graph: list[list[int]]) -> float:
6073
degrees = list(map(len, graph))
61-
return np.mean(degrees) if degrees != 0 else 0.0
74+
return float(np.mean(degrees)) if degrees != 0 else 0.0
75+
76+
@staticmethod
77+
@override
78+
def get_stat_name() -> str:
79+
return "AVGDEGREE"
6280

6381

64-
class GraphConnectedComponentsTestStatistic(AbstractGraphTestStatistic, ABC):
82+
class GraphConnectedComponentsTestStatistic(AbstractGraphTestStatistic):
6583
@staticmethod
6684
@override
67-
def get_graph_stat(graph):
85+
def get_graph_stat(graph) -> float:
6886
visited = set()
6987
components = 0
7088

@@ -81,3 +99,57 @@ def dfs(node):
8199
dfs(node)
82100
components += 1
83101
return components
102+
103+
@staticmethod
104+
@override
105+
def get_stat_name() -> str:
106+
return "CONNECTEDCOMPONENTS"
107+
108+
109+
class GraphCliqueNumberTestStatistic(AbstractGraphTestStatistic):
110+
@override
111+
def execute_statistic(self, rvs, **kwargs) -> float | float64:
112+
dist = self._compute_dist(rvs)
113+
rvs.sort()
114+
115+
right_border = 0
116+
clique_number = 0
117+
for left_border in range(len(rvs)):
118+
while right_border < len(rvs) and rvs[left_border] + dist > rvs[right_border]:
119+
right_border += 1
120+
if right_border == len(rvs):
121+
clique_number = max(clique_number, right_border - left_border + 1)
122+
break
123+
clique_number = max(clique_number, right_border - left_border)
124+
return clique_number
125+
126+
@staticmethod
127+
@override
128+
def get_stat_name() -> str:
129+
return "CLIQUENUMBER"
130+
131+
132+
class GraphIndependenceNumberTestStatistic(AbstractGraphTestStatistic):
133+
@override
134+
def execute_statistic(self, rvs, **kwargs) -> float | float64:
135+
if not rvs:
136+
return 0
137+
138+
dist = self._compute_dist(rvs)
139+
rvs.sort()
140+
141+
stat = 1
142+
last_chosen_position = rvs[0]
143+
144+
for i in range(1, len(rvs)):
145+
current_point = rvs[i]
146+
if current_point >= last_chosen_position + dist:
147+
stat += 1
148+
last_chosen_position = current_point
149+
150+
return stat
151+
152+
@staticmethod
153+
@override
154+
def get_stat_name() -> str:
155+
return "INDEPENDENCENUMBER"

0 commit comments

Comments
 (0)