2626)
2727from sagemaker .core .telemetry .telemetry_logging import _telemetry_emitter
2828from sagemaker .core .telemetry .constants import Feature
29- from sagemaker .train .constants import HUB_NAME
29+ from sagemaker .train .constants import HUB_NAME , _ALLOWED_REWARD_MODEL_IDS
3030
3131logger = logging .getLogger (__name__ )
3232
@@ -87,7 +87,6 @@ class RLAIFTrainer(BaseTrainer):
8787 ARN, or ModelPackageGroup object. Required when model is not a ModelPackage.
8888 reward_model_id (str):
8989 Bedrock model identifier for generating LLM feedback.
90- Evaluator models available: https://docs.aws.amazon.com/bedrock/latest/userguide/evaluation-judge.html
9190 Required for RLAIF training to provide reward signals.
9291 reward_prompt (Union[str, Evaluator]):
9392 The reward prompt or evaluator for AI feedback generation.
@@ -141,7 +140,7 @@ def __init__(
141140 self .training_type = training_type
142141 self .model_package_group_name = _validate_and_resolve_model_package_group (model ,
143142 model_package_group_name )
144- self .reward_model_id = reward_model_id
143+ self .reward_model_id = self . _validate_reward_model_id ( reward_model_id )
145144 self .reward_prompt = reward_prompt
146145 self .mlflow_resource_arn = mlflow_resource_arn
147146 self .mlflow_experiment_name = mlflow_experiment_name
@@ -165,6 +164,18 @@ def __init__(
165164
166165 # Process reward_prompt parameter
167166 self ._process_hyperparameters ()
167+
168+ def _validate_reward_model_id (self , reward_model_id ):
169+ """Validate reward_model_id is one of the allowed values."""
170+ if not reward_model_id :
171+ return None
172+
173+ if reward_model_id not in _ALLOWED_REWARD_MODEL_IDS :
174+ raise ValueError (
175+ f"Invalid reward_model_id '{ reward_model_id } '. "
176+ f"Available models are: { _ALLOWED_REWARD_MODEL_IDS } "
177+ )
178+ return reward_model_id
168179
169180
170181 @_telemetry_emitter (feature = Feature .MODEL_CUSTOMIZATION , func_name = "RLAIFTrainer.train" )
0 commit comments