Skip to content

Commit 08e0ddf

Browse files
authored
Liqun/refactor role (#441)
refactored the load example/experience functions to abstract a preparation process.
1 parent 9ddf42c commit 08e0ddf

File tree

1 file changed

+84
-72
lines changed

1 file changed

+84
-72
lines changed

taskweaver/role/role.py

Lines changed: 84 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os.path
33
from dataclasses import dataclass
44
from datetime import timedelta
5-
from typing import List, Optional, Set, Tuple, Union
5+
from typing import List, Literal, Optional, Set, Tuple, Union
66

77
from injector import Module, inject, provider
88

@@ -153,101 +153,113 @@ def format_experience(
153153
else ""
154154
)
155155

156+
def prepare_loading(
157+
self,
158+
use_flag: bool,
159+
dynamic_sub_path: bool,
160+
base_path: str,
161+
memory: Optional[Memory],
162+
loaded_from_attr: str,
163+
item_type: Literal["experience", "example"],
164+
) -> Optional[str]:
165+
"""Prepare for loading by checking configurations and memory, and return load_from path if applicable."""
166+
if not use_flag:
167+
setattr(self, f"{item_type}s", [])
168+
return None
169+
170+
if not os.path.exists(base_path):
171+
raise FileNotFoundError(
172+
f"The default {item_type} base path {base_path} does not exist."
173+
f"The original {item_type} base paths have been changed to `{item_type}s` folder."
174+
f"Please migrate the {item_type}s to the new base path.",
175+
)
176+
177+
sub_path = ""
178+
if dynamic_sub_path:
179+
assert memory is not None, f"Memory should be provided when dynamic_{item_type}_sub_path is True"
180+
sub_paths = memory.get_shared_memory_entries(entry_type=f"{item_type}_sub_path")
181+
if sub_paths:
182+
self.tracing.set_span_attribute(f"{item_type}_sub_path", str(sub_paths))
183+
# todo: handle multiple sub paths
184+
sub_path = sub_paths[0].content
185+
else:
186+
self.logger.info(f"No {item_type} sub path found in memory.")
187+
setattr(self, f"{item_type}s", [])
188+
return None
189+
190+
load_from = os.path.join(base_path, sub_path)
191+
if getattr(self, loaded_from_attr) is not None and getattr(self, loaded_from_attr) == load_from:
192+
self.logger.info(f"{item_type.capitalize()} already loaded from {load_from}.")
193+
return None
194+
195+
setattr(self, loaded_from_attr, load_from)
196+
return sub_path
197+
156198
def role_load_experience(
157199
self,
158200
query: str,
159201
memory: Optional[Memory] = None,
160202
) -> None:
161-
if not self.config.use_experience:
162-
self.experiences = []
203+
sub_path = self.prepare_loading(
204+
self.config.use_experience,
205+
self.config.dynamic_experience_sub_path,
206+
self.config.experience_dir,
207+
memory,
208+
"experience_loaded_from",
209+
"experience",
210+
)
211+
if sub_path is None:
163212
return
164213

165214
if self.experience_generator is None:
166215
raise ValueError(
167216
"Experience generator is not initialized. Each role instance should have its own generator.",
168217
)
169218

170-
experience_sub_path = ""
171-
if self.config.dynamic_experience_sub_path:
172-
assert memory is not None, "Memory should be provided when dynamic_experience_sub_path is True"
173-
experience_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path")
174-
if experience_sub_paths:
175-
self.tracing.set_span_attribute("experience_sub_path", str(experience_sub_paths))
176-
# todo: handle multiple experience sub paths
177-
experience_sub_path = experience_sub_paths[0].content
178-
else:
179-
self.logger.info("No experience sub path found in memory.")
180-
self.experiences = []
181-
return
182-
183-
load_from = os.path.join(self.config.experience_dir, experience_sub_path)
184-
if self.experience_loaded_from is None or self.experience_loaded_from != load_from:
185-
self.experience_loaded_from = load_from
186-
self.experience_generator.set_experience_dir(self.config.experience_dir)
187-
self.experience_generator.set_sub_path(experience_sub_path)
188-
self.experience_generator.refresh()
189-
self.experience_generator.load_experience()
190-
self.logger.info(
191-
"Experience loaded successfully for {}, there are {} experiences with filter [{}]".format(
192-
self.alias,
193-
len(self.experience_generator.experience_list),
194-
experience_sub_path,
195-
),
196-
)
197-
else:
198-
self.logger.info(f"Experience already loaded from {load_from}.")
219+
self.experience_generator.set_experience_dir(self.config.experience_dir)
220+
self.experience_generator.set_sub_path(sub_path)
221+
self.experience_generator.refresh()
222+
self.experience_generator.load_experience()
223+
self.logger.info(
224+
"Experience loaded successfully for {}, there are {} experiences with filter [{}]".format(
225+
self.alias,
226+
len(self.experience_generator.experience_list),
227+
sub_path,
228+
),
229+
)
199230

200231
experiences = self.experience_generator.retrieve_experience(query)
201232
self.logger.info(f"Retrieved {len(experiences)} experiences for query [{query}]")
202233
self.experiences = [exp for exp, _ in experiences]
203234

204-
# todo: `role_load_example` is similar to `role_load_experience`, consider refactoring
205235
def role_load_example(
206236
self,
207237
role_set: Set[str],
208238
memory: Optional[Memory] = None,
209239
) -> None:
210-
if not self.config.use_example:
211-
self.examples = []
240+
sub_path = self.prepare_loading(
241+
self.config.use_example,
242+
self.config.dynamic_example_sub_path,
243+
self.config.example_base_path,
244+
memory,
245+
"example_loaded_from",
246+
"example",
247+
)
248+
if sub_path is None:
212249
return
213250

214-
if not os.path.exists(self.config.example_base_path):
215-
raise FileNotFoundError(
216-
f"The default example base path {self.config.example_base_path} does not exist."
217-
"The original example base paths have been changed to `examples` folder."
218-
"Please migrate the examples to the new base path.",
219-
)
220-
221-
example_sub_path = ""
222-
if self.config.dynamic_example_sub_path:
223-
assert memory is not None, "Memory should be provided when dynamic_example_sub_path is True"
224-
example_sub_paths = memory.get_shared_memory_entries(entry_type="example_sub_path")
225-
if example_sub_paths:
226-
self.tracing.set_span_attribute("example_sub_path", str(example_sub_paths))
227-
# todo: handle multiple sub paths
228-
example_sub_path = example_sub_paths[0].content
229-
else:
230-
self.logger.info("No example sub path found in memory.")
231-
self.examples = []
232-
return
233-
234-
load_from = os.path.join(self.config.example_base_path, example_sub_path)
235-
if self.example_loaded_from is None or self.example_loaded_from != load_from:
236-
self.example_loaded_from = load_from
237-
self.examples = load_examples(
238-
folder=self.config.example_base_path,
239-
sub_path=example_sub_path,
240-
role_set=role_set,
241-
)
242-
self.logger.info(
243-
"Example loaded successfully for {}, there are {} examples with filter [{}]".format(
244-
self.alias,
245-
len(self.examples),
246-
example_sub_path,
247-
),
248-
)
249-
else:
250-
self.logger.info(f"Example already loaded from {load_from}.")
251+
self.examples = load_examples(
252+
folder=self.config.example_base_path,
253+
sub_path=sub_path,
254+
role_set=role_set,
255+
)
256+
self.logger.info(
257+
"Example loaded successfully for {}, there are {} examples with filter [{}]".format(
258+
self.alias,
259+
len(self.examples),
260+
sub_path,
261+
),
262+
)
251263

252264

253265
class RoleModuleConfig(ModuleConfig):

0 commit comments

Comments
 (0)