Skip to content

Commit eb19b8d

Browse files
fix(models): fix lint errors
1 parent 4f169db commit eb19b8d

File tree

7 files changed

+21
-18
lines changed

7 files changed

+21
-18
lines changed

models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@
3838
"UniEvaluator",
3939
# strategy models
4040
"TraverseStrategy",
41-
]
41+
]

models/embed/embedding.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1+
from dataclasses import dataclass
12
import asyncio
23
import numpy as np
34

4-
from dataclasses import dataclass
5-
65
class UnlimitedSemaphore:
76
"""A context manager that allows unlimited access."""
87

models/evaluate/length_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from .base_evaluator import BaseEvaluator
2+
from models.evaluate.base_evaluator import BaseEvaluator
33
from models.llm.tokenizer import Tokenizer
44
from models.text.text_pair import TextPair
55
from utils import create_event_loop
@@ -16,7 +16,7 @@ def __post_init__(self):
1616
async def evaluate_single(self, pair: TextPair) -> float:
1717
loop = create_event_loop()
1818
return await loop.run_in_executor(None, self._calculate_length, pair.answer)
19-
19+
2020
def _calculate_length(self, text: str) -> float:
2121
tokens = self.tokenizer.encode_string(text)
2222
return len(tokens)

models/llm/openai_model.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import math
22
from dataclasses import dataclass
33
from typing import List, Dict, Optional
4-
from openai import AsyncOpenAI, RateLimitError, APIConnectionError, APITimeoutError, ChatCompletion
5-
from models import TopkTokenModel, Token
4+
import openai
5+
from openai import AsyncOpenAI, RateLimitError, APIConnectionError, APITimeoutError
66
from tenacity import (
77
retry,
88
stop_after_attempt,
99
wait_exponential,
1010
retry_if_exception_type,
1111
)
1212

13-
def get_top_response_tokens(response: ChatCompletion) -> List[Token]:
13+
from models import TopkTokenModel, Token
14+
15+
16+
def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
1417
token_logprobs = response.choices[0].logprobs.content
1518
tokens = []
1619
for token_prob in token_logprobs:
@@ -76,6 +79,7 @@ async def generate_topk_per_token(self, text: str, history: Optional[List[str]]
7679

7780
completion = await self.client.chat.completions.create(
7881
model=self.model_name,
82+
messages=kwargs["messages"],
7983
**kwargs
8084
)
8185

@@ -94,7 +98,11 @@ async def generate_answer(self, text: str, history: Optional[List[str]] = None,
9498

9599
completion = await self.client.chat.completions.create(
96100
model=self.model_name,
101+
messages=kwargs["messages"],
97102
**kwargs
98103
)
99104

100105
return completion.choices[0].message.content
106+
107+
async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
108+
raise NotImplementedError

models/storage/base_storage.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ class StorageNameSpace:
1111

1212
async def index_done_callback(self):
1313
"""commit the storage operations after indexing"""
14-
pass
1514

1615
async def query_done_callback(self):
1716
"""commit the storage operations after querying"""
18-
pass
17+
1918

2019
@dataclass
2120
class BaseKVStorage(Generic[T], StorageNameSpace):

models/storage/json_storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def get_by_ids(self, ids, fields=None):
3636
]
3737

3838
async def filter_keys(self, data: list[str]) -> set[str]:
39-
return set([s for s in data if s not in self._data])
39+
return {s for s in data if s not in self._data}
4040

4141
async def upsert(self, data: dict):
4242
left_data = {k: v for k, v in data.items() if k not in self._data}

models/storage/networkx_storage.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ def load_nx_graph(file_name) -> Optional[nx.Graph]:
1717

1818
@staticmethod
1919
def write_nx_graph(graph: nx.Graph, file_name):
20-
logger.info(
21-
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
22-
)
20+
logger.info("Writing graph with %d nodes, %d edges", graph.number_of_nodes(), graph.number_of_edges())
2321
nx.write_graphml(graph, file_name)
2422

2523
@staticmethod
@@ -56,9 +54,7 @@ def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
5654
def _sort_source_target(edge):
5755
source, target, edge_data = edge
5856
if source > target:
59-
temp = source
60-
source = target
61-
target = temp
57+
source, target = target, source
6258
return source, target, edge_data
6359

6460
edges = [_sort_source_target(edge) for edge in edges]
@@ -81,7 +77,8 @@ def __post_init__(self):
8177
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
8278
if preloaded_graph is not None:
8379
logger.info(
84-
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
80+
"Loaded graph from %s with %d nodes, %d edges", self._graphml_xml_file,
81+
preloaded_graph.number_of_nodes(), preloaded_graph.number_of_edges()
8582
)
8683
self._graph = preloaded_graph or nx.Graph()
8784

0 commit comments

Comments (0)