Skip to content

Commit 35fbbfe

Browse files
Remove @dataclass from ABC base classes and fix search classes
Co-authored-by: ChenZiHong-Gavin <[email protected]>
1 parent b218448 commit 35fbbfe

File tree

12 files changed, with 55 additions and 59 deletions.

graphgen/bases/base_generator.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,16 @@
11
from abc import ABC, abstractmethod
2-
from dataclasses import dataclass
32
from typing import Any
43

54
from graphgen.bases.base_llm_client import BaseLLMClient
65

76

8-
@dataclass
97
class BaseGenerator(ABC):
108
"""
119
Generate QAs based on given prompts.
1210
"""
1311

14-
llm_client: BaseLLMClient
12+
def __init__(self, llm_client: BaseLLMClient):
13+
self.llm_client = llm_client
1514

1615
@staticmethod
1716
@abstractmethod

graphgen/bases/base_kg_builder.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,17 @@
11
from abc import ABC, abstractmethod
22
from collections import defaultdict
3-
from dataclasses import dataclass, field
43
from typing import Dict, List, Tuple
54

65
from graphgen.bases.base_llm_client import BaseLLMClient
76
from graphgen.bases.base_storage import BaseGraphStorage
87
from graphgen.bases.datatypes import Chunk
98

109

11-
@dataclass
1210
class BaseKGBuilder(ABC):
13-
llm_client: BaseLLMClient
14-
15-
_nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
16-
_edges: Dict[Tuple[str, str], List[dict]] = field(
17-
default_factory=lambda: defaultdict(list)
18-
)
11+
def __init__(self, llm_client: BaseLLMClient):
12+
self.llm_client = llm_client
13+
self._nodes: Dict[str, List[dict]] = defaultdict(list)
14+
self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
1915

2016
@abstractmethod
2117
async def extract(

graphgen/bases/base_partitioner.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,10 @@
11
from abc import ABC, abstractmethod
2-
from dataclasses import dataclass
32
from typing import Any, List
43

54
from graphgen.bases.base_storage import BaseGraphStorage
65
from graphgen.bases.datatypes import Community
76

87

9-
@dataclass
108
class BasePartitioner(ABC):
119
@abstractmethod
1210
async def partition(

graphgen/bases/base_splitter.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,32 @@
11
import copy
22
import re
33
from abc import ABC, abstractmethod
4-
from dataclasses import dataclass
54
from typing import Callable, Iterable, List, Literal, Optional, Union
65

76
from graphgen.bases.datatypes import Chunk
87
from graphgen.utils import logger
98

109

11-
@dataclass
1210
class BaseSplitter(ABC):
1311
"""
1412
Abstract base class for splitting text into smaller chunks.
1513
"""
1614

17-
chunk_size: int = 1024
18-
chunk_overlap: int = 100
19-
length_function: Callable[[str], int] = len
20-
keep_separator: bool = False
21-
add_start_index: bool = False
22-
strip_whitespace: bool = True
15+
def __init__(
16+
self,
17+
chunk_size: int = 1024,
18+
chunk_overlap: int = 100,
19+
length_function: Callable[[str], int] = len,
20+
keep_separator: bool = False,
21+
add_start_index: bool = False,
22+
strip_whitespace: bool = True,
23+
):
24+
self.chunk_size = chunk_size
25+
self.chunk_overlap = chunk_overlap
26+
self.length_function = length_function
27+
self.keep_separator = keep_separator
28+
self.add_start_index = add_start_index
29+
self.strip_whitespace = strip_whitespace
2330

2431
@abstractmethod
2532
def split_text(self, text: str) -> List[str]:

graphgen/bases/base_storage.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,12 @@
1-
from dataclasses import dataclass
21
from typing import Generic, TypeVar, Union
32

43
T = TypeVar("T")
54

65

7-
@dataclass
86
class StorageNameSpace:
9-
working_dir: str = None
10-
namespace: str = None
7+
def __init__(self, working_dir: str = None, namespace: str = None):
8+
self.working_dir = working_dir
9+
self.namespace = namespace
1110

1211
async def index_done_callback(self):
1312
"""commit the storage operations after indexing"""
@@ -16,7 +15,6 @@ async def query_done_callback(self):
1615
"""commit the storage operations after querying"""
1716

1817

19-
@dataclass
2018
class BaseListStorage(Generic[T], StorageNameSpace):
2119
async def all_items(self) -> list[T]:
2220
raise NotImplementedError
@@ -34,7 +32,6 @@ async def drop(self):
3432
raise NotImplementedError
3533

3634

37-
@dataclass
3835
class BaseKVStorage(Generic[T], StorageNameSpace):
3936
async def all_keys(self) -> list[str]:
4037
raise NotImplementedError
@@ -58,7 +55,6 @@ async def drop(self):
5855
raise NotImplementedError
5956

6057

61-
@dataclass
6258
class BaseGraphStorage(StorageNameSpace):
6359
async def has_node(self, node_id: str) -> bool:
6460
raise NotImplementedError

graphgen/bases/base_tokenizer.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,12 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from dataclasses import dataclass
54
from typing import List
65

76

8-
@dataclass
97
class BaseTokenizer(ABC):
10-
model_name: str = "cl100k_base"
8+
def __init__(self, model_name: str = "cl100k_base"):
9+
self.model_name = model_name
1110

1211
@abstractmethod
1312
def encode(self, text: str) -> List[int]:

graphgen/models/evaluator/base_evaluator.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,15 @@
11
import asyncio
2-
from dataclasses import dataclass
32

43
from tqdm.asyncio import tqdm as tqdm_async
54

65
from graphgen.bases.datatypes import QAPair
76
from graphgen.utils import create_event_loop
87

98

10-
@dataclass
119
class BaseEvaluator:
12-
max_concurrent: int = 100
13-
results: list[float] = None
10+
def __init__(self, max_concurrent: int = 100):
11+
self.max_concurrent = max_concurrent
12+
self.results: list[float] = None
1413

1514
def evaluate(self, pairs: list[QAPair]) -> list[float]:
1615
"""

graphgen/models/llm/topk_token_model.py

Lines changed: 24 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,28 +1,39 @@
1-
from dataclasses import dataclass
1+
from abc import ABC, abstractmethod
22
from typing import List, Optional
33

44
from graphgen.bases import Token
55

66

7-
@dataclass
8-
class TopkTokenModel:
9-
do_sample: bool = False
10-
temperature: float = 0
11-
max_tokens: int = 4096
12-
repetition_penalty: float = 1.05
13-
num_beams: int = 1
14-
topk: int = 50
15-
topp: float = 0.95
16-
17-
topk_per_token: int = 5 # number of topk tokens to generate for each token
7+
class TopkTokenModel(ABC):
8+
def __init__(
9+
self,
10+
do_sample: bool = False,
11+
temperature: float = 0,
12+
max_tokens: int = 4096,
13+
repetition_penalty: float = 1.05,
14+
num_beams: int = 1,
15+
topk: int = 50,
16+
topp: float = 0.95,
17+
topk_per_token: int = 5,
18+
):
19+
self.do_sample = do_sample
20+
self.temperature = temperature
21+
self.max_tokens = max_tokens
22+
self.repetition_penalty = repetition_penalty
23+
self.num_beams = num_beams
24+
self.topk = topk
25+
self.topp = topp
26+
self.topk_per_token = topk_per_token
1827

28+
@abstractmethod
1929
async def generate_topk_per_token(self, text: str) -> List[Token]:
2030
"""
2131
Generate prob, text and candidates for each token of the model's output.
2232
This function is used to visualize the inference process.
2333
"""
2434
raise NotImplementedError
2535

36+
@abstractmethod
2637
async def generate_inputs_prob(
2738
self, text: str, history: Optional[List[str]] = None
2839
) -> List[Token]:
@@ -32,6 +43,7 @@ async def generate_inputs_prob(
3243
"""
3344
raise NotImplementedError
3445

46+
@abstractmethod
3547
async def generate_answer(
3648
self, text: str, history: Optional[List[str]] = None
3749
) -> str:

graphgen/models/search/db/uniprot_search.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,3 @@
1-
from dataclasses import dataclass
2-
31
import requests
42
from fastapi import HTTPException
53

@@ -8,7 +6,6 @@
86
UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
97

108

11-
@dataclass
129
class UniProtSearch:
1310
"""
1411
UniProt Search client to search with UniProt.

graphgen/models/search/kg/wiki_search.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
from dataclasses import dataclass
21
from typing import List, Union
32

43
import wikipedia
@@ -7,7 +6,6 @@
76
from graphgen.utils import detect_main_language, logger
87

98

10-
@dataclass
119
class WikiSearch:
1210
@staticmethod
1311
def set_language(language: str):

0 commit comments

Comments (0)