Skip to content

Commit 7d42c56

Browse files
refactor: using yaml file for config
1 parent 2affb12 commit 7d42c56

File tree

1 file changed

+18
-38
lines changed

1 file changed

+18
-38
lines changed

generate.py

Lines changed: 18 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -24,35 +24,24 @@ def save_config(global_config):
2424

2525
if __name__ == '__main__':
2626
parser = argparse.ArgumentParser()
27-
parser.add_argument('--input_file',
28-
help='Raw context jsonl path.',
29-
default='resources/examples/chunked_demo.json',
27+
parser.add_argument('--config_file',
28+
help='Config parameters for GraphGen.',
29+
default='graphgen_config.yaml',
3030
type=str)
31-
parser.add_argument('--data_type',
32-
help='Data type of input file. (Raw context or chunked context)',
33-
choices=['raw', 'chunked'],
34-
default='raw',
35-
type=str)
36-
parser.add_argument('--web_search',
37-
help='Search node info from wiki.',
38-
action='store_true',
39-
default=False)
40-
parser.add_argument('--tokenizer',
41-
help='Tokenizer name.',
42-
default='cl100k_base',
43-
type=str)
44-
4531
args = parser.parse_args()
46-
input_file = args.input_file
32+
with open(args.config_file, "r", encoding='utf-8') as f:
33+
config = yaml.load(f, Loader=yaml.FullLoader)
4734

48-
if args.data_type == 'raw':
35+
input_file = config['input_file']
36+
37+
if config['data_type'] == 'raw':
4938
with open(input_file, "r", encoding='utf-8') as f:
5039
data = [json.loads(line) for line in f]
51-
elif args.data_type == 'chunked':
40+
elif config['data_type'] == 'chunked':
5241
with open(input_file, "r", encoding='utf-8') as f:
5342
data = json.load(f)
5443
else:
55-
raise ValueError(f"Invalid data type: {args.data_type}")
44+
raise ValueError(f"Invalid data type: {config['data_type']}")
5645

5746
synthesizer_llm_client = OpenAIModel(
5847
model_name=os.getenv("TEACHER_MODEL"),
@@ -65,36 +54,27 @@ def save_config(global_config):
6554
base_url=os.getenv("STUDENT_BASE_URL")
6655
)
6756

68-
traverse_strategy = TraverseStrategy()
57+
traverse_strategy = TraverseStrategy(
58+
**config['traverse_strategy']
59+
)
6960

7061
graph_gen = GraphGen(
7162
unique_id=unique_id,
72-
teacher_llm_client=synthesizer_llm_client,
73-
student_llm_client=training_llm_client,
74-
if_web_search=args.web_search,
63+
synthesizer_llm_client=synthesizer_llm_client,
64+
training_llm_client=training_llm_client,
65+
if_web_search=config['web_search'],
7566
tokenizer_instance=Tokenizer(
76-
model_name=args.tokenizer
67+
model_name=config['tokenizer']
7768
),
7869
traverse_strategy=traverse_strategy
7970
)
8071

81-
graph_gen.insert(data, args.data_type)
72+
graph_gen.insert(data, config['data_type'])
8273

8374
graph_gen.quiz(max_samples=2)
8475

8576
graph_gen.judge(re_judge=False)
8677

8778
graph_gen.traverse()
8879

89-
config = {
90-
"unique_id": unique_id,
91-
"input_file": input_file,
92-
"data_type": args.data_type,
93-
"web_search": args.web_search,
94-
"tokenizer": args.tokenizer,
95-
"teacher_model": os.getenv("TEACHER_MODEL"),
96-
"student_model": os.getenv("STUDENT_MODEL"),
97-
}
98-
traverse_strategy_yaml = traverse_strategy.to_yaml()
99-
config.update(traverse_strategy_yaml)
10080
save_config(config)

0 commit comments

Comments (0)