Skip to content

Commit 8b33ae1

Browse files
authored
Fix bug in grpo reward module import (axolotl-ai-cloud#2571)
1 parent dc4da4a commit 8b33ae1

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/axolotl/core/trainers/grpo/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
135135
try:
136136
# use importlib to dynamically load the reward function from the module
137137
reward_func_module_name = reward_func_fqn.split(".")[-1]
138-
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
138+
reward_func_module = importlib.import_module(
139+
".".join(reward_func_fqn.split(".")[:-1])
140+
)
139141
reward_func = getattr(reward_func_module, reward_func_module_name)
140142
if not len(inspect.signature(reward_func).parameters) >= 2:
141143
raise ValueError(

0 commit comments

Comments
 (0)