We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent dc4da4a commit 8b33ae1Copy full SHA for 8b33ae1
src/axolotl/core/trainers/grpo/__init__.py
@@ -135,7 +135,9 @@ def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
135
try:
136
# use importlib to dynamically load the reward function from the module
137
reward_func_module_name = reward_func_fqn.split(".")[-1]
138
- reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
+ reward_func_module = importlib.import_module(
139
+ ".".join(reward_func_fqn.split(".")[:-1])
140
+ )
141
reward_func = getattr(reward_func_module, reward_func_module_name)
142
if not len(inspect.signature(reward_func).parameters) >= 2:
143
raise ValueError(
0 commit comments