Skip to content

Commit e92bb5c

Browse files
authored
Merge pull request #14 from Open-DataFlow/sunnyhaze
Revise calling style to nn.module style
2 parents 2da25fd + 223424f commit e92bb5c

39 files changed

+680
-276
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,8 @@
44
*.egg
55
*.egg-info
66
/dataflow/example/ReasoningPipeline/pipeline_math_step*.jsonl
7-
!/dataflow/example/ReasoningPipeline/pipeline_math.json
7+
!/dataflow/example/ReasoningPipeline/pipeline_math.json
8+
9+
test/example
10+
11+
cache

dataflow/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from .utils import *
22
from .version import __version__, version_info
3+
from .logger import get_logger
34

4-
5-
__all__ = ['__version__', 'version_info']
5+
__all__ = [
6+
'__version__',
7+
'version_info',
8+
'get_logger',
9+
]
610

711

812

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
from typing import Any, List
33

44

5-
class Generator(ABC):
5+
class GeneratorABC(ABC):
6+
"""Abstract base class for data generators, which may be used to generate data from a model or an API. Called by operators.
7+
"""
68
@abstractmethod
79
def generate(self) -> Any:
810
"""
@@ -12,7 +14,7 @@ def generate(self) -> Any:
1214
pass
1315

1416
@abstractmethod
15-
def generate_from_input(self, input: List[str]) -> List[str]:
17+
def generate_from_input(self, input: List[str], system_prompt: str) -> List[str]:
1618
"""
1719
Generate data from input.
1820
input: List[str], the input of the generator

dataflow/core/Operator.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from abc import ABC, abstractmethod
2+
from dataflow.logger import get_logger
3+
4+
class OperatorABC(ABC):
5+
6+
# @abstractmethod
7+
# def check_config(self, config: dict) -> None:
8+
# """
9+
# Check the config of the operator. If config lacks any required keys, raise an error.
10+
# """
11+
# pass
12+
13+
@abstractmethod
14+
def run(self) -> None:
15+
"""
16+
Main function to run the operator.
17+
"""
18+
pass
19+
20+
def get_operator(operator_name, args) -> OperatorABC:
21+
from dataflow.utils import OPERATOR_REGISTRY
22+
print(operator_name, args)
23+
operator = OPERATOR_REGISTRY.get(operator_name)(args)
24+
logger = get_logger()
25+
if operator is not None:
26+
logger.info(f"Successfully get operator {operator_name}, args {args}")
27+
else:
28+
logger.error(f"operator {operator_name} is not found")
29+
assert operator is not None
30+
print(operator)
31+
return operator

dataflow/core/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .Operator import OperatorABC, get_operator
2+
from .Generator import GeneratorABC
3+
__all__ = [
4+
'OperatorABC',
5+
'get_operator',
6+
'GeneratorABC',
7+
]

dataflow/example/ReasoningPipeline/pipeline_math_short.json

Lines changed: 260 additions & 0 deletions
Large diffs are not rendered by default.

dataflow/utils/APIGenerator_aisuite.py renamed to dataflow/generators/APIGenerator_aisuite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import aisuite as ai
33
import pandas as pd
44
from tqdm import tqdm
5-
from dataflow.utils.Generator import Generator
5+
from dataflow.core import GeneratorABC
66
from dataflow.utils.Storage import FileStorage
77
from typing import List
88

9-
class APIGenerator_aisuite(Generator):
9+
class APIGenerator_aisuite(GeneratorABC):
1010
def __init__(self, config: dict):
1111
configs = config # Assuming config.configs is a list of configurations
1212

dataflow/utils/APIGenerator_request.py renamed to dataflow/generators/APIGenerator_request.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,37 @@
22
import requests
33
import os
44
import logging
5-
import pandas as pd
65
from concurrent.futures import ThreadPoolExecutor, as_completed
76
from tqdm import tqdm
8-
from dataflow.utils.Storage import FileStorage
9-
from dataflow.utils.Generator import Generator
7+
from dataflow.core import GeneratorABC
108
import re
119

12-
class APIGenerator_request(Generator):
13-
def __init__(self, config: dict):
14-
self.config = config
15-
10+
class APIGenerator_request(GeneratorABC):
11+
"""Use OpenAI API to generate responses based on input messages.
12+
"""
13+
def __init__(self,
14+
api_url: str = "https://api.openai.com/v1/chat/completions",
15+
model_name: str = "gpt-4o",
16+
max_workers: int = 10
17+
):
1618
# Get API key from environment variable or config
17-
self.api_url = self.config.get("api_url", "https://api.openai.com/v1/chat/completions")
19+
self.api_url = api_url
20+
self.model_name = model_name
21+
self.max_workers = max_workers
22+
23+
# Configure api_key globally via os.environ, for safety reasons.
1824
self.api_key = os.environ.get("API_KEY")
1925
if self.api_key is None:
2026
raise ValueError("Lack of API_KEY")
2127

22-
self.datastorage = FileStorage(self.config)
23-
28+
"""Cordoned off because I'm not confident about the implementation——Sunnyhaze
2429
def check_config(self):
2530
# Ensure all necessary keys are in the config
2631
necessary_keys = ['input_file', 'output_file', 'input_key', 'output_key', 'max_workers']
2732
for key in necessary_keys:
2833
if key not in self.config:
2934
raise ValueError(f"Key {key} is missing in the config")
30-
35+
"""
3136
def format_response(self, response: dict) -> str:
3237
# check if content is formatted like <think>...</think>...<answer>...</answer>
3338
content = response['choices'][0]['message']['content']
@@ -74,6 +79,9 @@ def api_chat(self, system_info: str, messages: str, model: str):
7479
logging.error(f"API request error: {e}")
7580
return None
7681

82+
def generate(self):
83+
pass # for develop, TODO
84+
""" Cordoned off because I'm not confident about the implementation——Sunnyhaze
7785
def generate(self):
7886
self.check_config()
7987
# Read input file
@@ -105,8 +113,8 @@ def generate(self):
105113
raw_dataframe[self.config['output_key']] = responses
106114
self.datastorage.write(self.config['output_file'], raw_dataframe)
107115
return
108-
109-
def generate_from_input(self, input: list[str]) -> list[str]:
116+
"""
117+
def generate_from_input(self, input: list[str], system_prompt: str = "") -> list[str]:
110118
def api_chat_with_id(system_info: str, messages: str, model: str, id):
111119
try:
112120
payload = json.dumps({
@@ -136,17 +144,18 @@ def api_chat_with_id(system_info: str, messages: str, model: str, id):
136144
logging.error(f"API request error: {e}")
137145
return id,None
138146
responses = [None] * len(input)
139-
147+
# -- end of subfunction api_chat_with_id --
148+
140149
# Use ThreadPoolExecutor to process multiple questions in parallel
141150
# logging.info(f"Generating {len(questions)} responses")
142-
with ThreadPoolExecutor(max_workers=self.config['max_workers']) as executor:
151+
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
143152
futures = [
144153
executor.submit(
145154
api_chat_with_id,
146-
self.config['system_prompt'],
147-
question,
148-
self.config['model_name'],
149-
idx
155+
system_info = system_prompt,
156+
messages = question,
157+
model = self.model_name,
158+
id = idx
150159
) for idx, question in enumerate(input)
151160
]
152161
for future in tqdm(as_completed(futures), total=len(futures), desc="Generating......"):

dataflow/utils/LocalModelGenerator.py renamed to dataflow/generators/LocalModelGenerator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from huggingface_hub import snapshot_download
44
import pandas as pd
55
from dataflow.utils.Storage import FileStorage
6-
from dataflow.utils.Generator import Generator
6+
from dataflow.core import GeneratorABC
77

8-
class LocalModelGenerator(Generator):
8+
class LocalModelGenerator(GeneratorABC):
99
'''
1010
A class for generating text using vllm, with model from huggingface or local directory
1111
'''

dataflow/generators/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .APIGenerator_aisuite import APIGenerator_aisuite
2+
from .APIGenerator_request import APIGenerator_request
3+
from .LocalModelGenerator import LocalModelGenerator
4+
5+
__all__ = [
6+
"APIGenerator_aisuite",
7+
"APIGenerator_request",
8+
"LocalModelGenerator"
9+
]

0 commit comments

Comments
 (0)