Skip to content

Commit 0ed7f49

Browse files
Remove @DataClass from all subclasses following generator pattern
Co-authored-by: ChenZiHong-Gavin <[email protected]>
1 parent e5954e9 commit 0ed7f49

File tree

11 files changed

+16
-47
lines changed

11 files changed

+16
-47
lines changed

graphgen/models/evaluator/length_evaluator.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
1-
from dataclasses import dataclass
2-
31
from graphgen.bases.datatypes import QAPair
42
from graphgen.models.evaluator.base_evaluator import BaseEvaluator
53
from graphgen.models.tokenizer import Tokenizer
64
from graphgen.utils import create_event_loop
75

86

9-
@dataclass
107
class LengthEvaluator(BaseEvaluator):
11-
tokenizer_name: str = "cl100k_base"
12-
13-
def __post_init__(self):
8+
def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100):
9+
super().__init__(max_concurrent)
10+
self.tokenizer_name = tokenizer_name
1411
self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
1512

1613
async def evaluate_single(self, pair: QAPair) -> float:

graphgen/models/evaluator/mtld_evaluator.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from dataclasses import dataclass, field
21
from typing import Set
32

43
from graphgen.bases.datatypes import QAPair
@@ -8,18 +7,15 @@
87
nltk_helper = NLTKHelper()
98

109

11-
@dataclass
1210
class MTLDEvaluator(BaseEvaluator):
1311
"""
1412
衡量文本词汇多样性的指标
1513
"""
1614

17-
stopwords_en: Set[str] = field(
18-
default_factory=lambda: set(nltk_helper.get_stopwords("english"))
19-
)
20-
stopwords_zh: Set[str] = field(
21-
default_factory=lambda: set(nltk_helper.get_stopwords("chinese"))
22-
)
15+
def __init__(self, max_concurrent: int = 100):
16+
super().__init__(max_concurrent)
17+
self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english"))
18+
self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese"))
2319

2420
async def evaluate_single(self, pair: QAPair) -> float:
2521
loop = create_event_loop()

graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import re
22
from collections import Counter, defaultdict
3-
from dataclasses import dataclass
43
from typing import Dict, List, Tuple
54

65
from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
@@ -15,10 +14,10 @@
1514
)
1615

1716

18-
@dataclass
1917
class LightRAGKGBuilder(BaseKGBuilder):
20-
llm_client: BaseLLMClient = None
21-
max_loop: int = 3
18+
def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3):
19+
super().__init__(llm_client)
20+
self.max_loop = max_loop
2221

2322
async def extract(
2423
self, chunk: Chunk

graphgen/models/kg_builder/mm_kg_builder.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import re
22
from collections import defaultdict
3-
from dataclasses import dataclass
43
from typing import Dict, List, Tuple
54

65
from graphgen.bases import BaseLLMClient, Chunk
@@ -16,11 +15,7 @@
1615
from .light_rag_kg_builder import LightRAGKGBuilder
1716

1817

19-
@dataclass
2018
class MMKGBuilder(LightRAGKGBuilder):
21-
llm_client: BaseLLMClient = None
22-
max_loop: int = 3
23-
2419
async def extract(
2520
self, chunk: Chunk
2621
) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:

graphgen/models/partitioner/bfs_partitioner.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import random
22
from collections import deque
3-
from dataclasses import dataclass
43
from typing import Any, List
54

65
from graphgen.bases import BaseGraphStorage, BasePartitioner
@@ -10,7 +9,6 @@
109
EDGE_UNIT: str = "e"
1110

1211

13-
@dataclass
1412
class BFSPartitioner(BasePartitioner):
1513
"""
1614
BFS partitioner that partitions the graph into communities of a fixed size.

graphgen/models/partitioner/dfs_partitioner.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import random
2-
from dataclasses import dataclass
32
from typing import Any, List
43

54
from graphgen.bases import BaseGraphStorage, BasePartitioner
@@ -9,7 +8,6 @@
98
EDGE_UNIT: str = "e"
109

1110

12-
@dataclass
1311
class DFSPartitioner(BasePartitioner):
1412
"""
1513
DFS partitioner that partitions the graph into communities of a fixed size.

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import random
3-
from dataclasses import dataclass
43
from typing import Any, Dict, List, Optional, Set, Tuple
54

65
from tqdm.asyncio import tqdm as tqdm_async
@@ -13,7 +12,6 @@
1312
EDGE_UNIT: str = "e"
1413

1514

16-
@dataclass
1715
class ECEPartitioner(BFSPartitioner):
1816
"""
1917
ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).

graphgen/models/partitioner/leiden_partitioner.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections import defaultdict
2-
from dataclasses import dataclass
32
from typing import Any, Dict, List, Set, Tuple
43

54
import igraph as ig
@@ -9,7 +8,6 @@
98
from graphgen.bases.datatypes import Community
109

1110

12-
@dataclass
1311
class LeidenPartitioner(BasePartitioner):
1412
"""
1513
Leiden partitioner that partitions the graph into communities using the Leiden algorithm.

graphgen/models/tokenizer/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from dataclasses import dataclass, field
21
from typing import List
32

43
from graphgen.bases import BaseTokenizer
@@ -30,16 +29,13 @@ def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer:
3029
)
3130

3231

33-
@dataclass
3432
class Tokenizer(BaseTokenizer):
3533
"""
3634
Encapsulates different tokenization implementations based on the specified model name.
3735
"""
3836

39-
model_name: str = "cl100k_base"
40-
_impl: BaseTokenizer = field(init=False, repr=False)
41-
42-
def __post_init__(self):
37+
def __init__(self, model_name: str = "cl100k_base"):
38+
super().__init__(model_name)
4339
if not self.model_name:
4440
raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.")
4541
self._impl = get_tokenizer_impl(self.model_name)

graphgen/models/tokenizer/hf_tokenizer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
1-
from dataclasses import dataclass
21
from typing import List
32

43
from transformers import AutoTokenizer
54

65
from graphgen.bases import BaseTokenizer
76

87

9-
@dataclass
108
class HFTokenizer(BaseTokenizer):
11-
model_name: str = "cl100k_base"
12-
13-
def __post_init__(self):
9+
def __init__(self, model_name: str = "cl100k_base"):
10+
super().__init__(model_name)
1411
self.enc = AutoTokenizer.from_pretrained(self.model_name)
1512

1613
def encode(self, text: str) -> List[int]:

0 commit comments

Comments (0)