Skip to content

Commit c71e957

Browse files
author
Roja Reddy Sareddy
committed
Fix: Add validation to bedrock reward models
1 parent 9679c5d commit c71e957

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

sagemaker-train/src/sagemaker/train/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
HUB_NAME = "SageMakerPublicHub"
4444

4545
# Allowed reward model IDs for RLAIF trainer
46-
ALLOWED_REWARD_MODEL_IDS = [
46+
_ALLOWED_REWARD_MODEL_IDS = [
4747
"openai.gpt-oss-120b-1:0",
4848
"openai.gpt-oss-20b-1:0",
4949
"qwen.qwen3-32b-v1:0",

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2828
from sagemaker.core.telemetry.constants import Feature
29-
from sagemaker.train.constants import HUB_NAME, ALLOWED_REWARD_MODEL_IDS
29+
from sagemaker.train.constants import HUB_NAME, _ALLOWED_REWARD_MODEL_IDS
3030

3131
logger = logging.getLogger(__name__)
3232

@@ -170,10 +170,10 @@ def _validate_reward_model_id(self, reward_model_id):
170170
if not reward_model_id:
171171
return None
172172

173-
if reward_model_id not in ALLOWED_REWARD_MODEL_IDS:
173+
if reward_model_id not in _ALLOWED_REWARD_MODEL_IDS:
174174
raise ValueError(
175175
f"Invalid reward_model_id '{reward_model_id}'. "
176-
f"Available models are: {ALLOWED_REWARD_MODEL_IDS}"
176+
f"Available models are: {_ALLOWED_REWARD_MODEL_IDS}"
177177
)
178178
return reward_model_id
179179

0 commit comments

Comments
 (0)