@@ -32,6 +32,7 @@
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
@@ -63,6 +64,7 @@
 import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUnsupportedTaskTypeStatusException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -89,11 +91,16 @@ public class AmazonBedrockService extends SenderService {
 
     private final Sender amazonBedrockSender;
 
-    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
+    // The task types exposed via the _inference/_services API
+    private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
         TaskType.TEXT_EMBEDDING,
         TaskType.COMPLETION,
         TaskType.CHAT_COMPLETION
     );
+    /**
+     * The task types that the {@link InferenceAction.Request} can accept.
+     */
+    private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
 
     private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
         InputType.INGEST,
@@ -154,6 +161,11 @@ protected void doInfer(
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
+        if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
+            listener.onFailure(createUnsupportedTaskTypeStatusException(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES));
+            return;
+        }
+
         if (model instanceof AmazonBedrockModel == false) {
             listener.onFailure(createInvalidModelException(model));
             return;
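For readers following the change, here is a minimal, self-contained sketch (not part of the PR; the class, method, and listener types are hypothetical stand-ins) of the early-return guard that the new doInfer check applies: requests whose task type is not in the inference-action set are rejected before any Bedrock call is attempted. The error message wording below is illustrative only; in the real code the failure comes from ServiceUtils.createUnsupportedTaskTypeStatusException.

```java
import java.util.EnumSet;
import java.util.function.Consumer;

public class TaskTypeGuardSketch {

    // Stand-in for org.elasticsearch.inference.TaskType.
    enum TaskType { TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION }

    // Mirrors SUPPORTED_INFERENCE_ACTION_TASK_TYPES from the diff above.
    private static final EnumSet<TaskType> INFERENCE_ACTION_TASK_TYPES =
        EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);

    // The guard pattern: fail the caller up front when the task type is not
    // accepted by the inference action, rather than deeper in the request path.
    static void infer(TaskType modelTaskType, Consumer<String> onFailure, Runnable doInference) {
        if (INFERENCE_ACTION_TASK_TYPES.contains(modelTaskType) == false) {
            onFailure.accept(
                "task type [" + modelTaskType + "] is not supported; supported task types are " + INFERENCE_ACTION_TASK_TYPES
            );
            return;
        }
        doInference.run();
    }

    public static void main(String[] args) {
        infer(TaskType.CHAT_COMPLETION, System.err::println, () -> System.out.println("ran inference")); // rejected
        infer(TaskType.COMPLETION, System.err::println, () -> System.out.println("ran inference"));      // runs
    }
}
```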
@@ -298,7 +310,7 @@ public InferenceServiceConfiguration getConfiguration() {
 
     @Override
     public EnumSet<TaskType> supportedTaskTypes() {
-        return supportedTaskTypes;
+        return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
     }
 
     private static AmazonBedrockModel createModel(
@@ -429,7 +441,9 @@ public static InferenceServiceConfiguration get() {
 
         configurationMap.put(
             PROVIDER_FIELD,
-            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The model provider for your deployment.")
+            new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription(
+                "The model provider for your deployment."
+            )
                 .setLabel("Provider")
                 .setRequired(true)
                 .setSensitive(false)
@@ -440,7 +454,7 @@ public static InferenceServiceConfiguration get() {
 
         configurationMap.put(
             MODEL_FIELD,
-            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
+            new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription(
                 "The base model ID or an ARN to a custom model based on a foundational model."
             )
                 .setLabel("Model")
@@ -453,7 +467,7 @@ public static InferenceServiceConfiguration get() {
 
         configurationMap.put(
             REGION_FIELD,
-            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
+            new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription(
                 "The region that your model or ARN is deployed in."
             )
                 .setLabel("Region")
@@ -482,13 +496,13 @@ public static InferenceServiceConfiguration get() {
         configurationMap.putAll(
             RateLimitSettings.toSettingsConfigurationWithDescription(
                 "By default, the amazonbedrock service sets the number of requests allowed per minute to 240.",
-                supportedTaskTypes
+                SUPPORTED_TASK_TYPES_FOR_SERVICES_API
             )
         );
 
         return new InferenceServiceConfiguration.Builder().setService(NAME)
             .setName(SERVICE_NAME)
-            .setTaskTypes(supportedTaskTypes)
+            .setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
             .setConfigurations(configurationMap)
             .build();
     }
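A side note on why two constants exist: CHAT_COMPLETION is advertised through the services API set but is absent from the set the InferenceAction request path accepts, per the hunks above. A small illustrative check of that relationship (again with a stand-in TaskType enum, not the Elasticsearch one):

```java
import java.util.EnumSet;

public class TaskTypeSetsSketch {

    enum TaskType { TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION }

    public static void main(String[] args) {
        // Values copied from the diff: the services-API set vs. the inference-action set.
        EnumSet<TaskType> servicesApi = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
        EnumSet<TaskType> inferenceAction = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);

        // Every task type the inference action accepts is also exposed via the services API,
        // but not the reverse: CHAT_COMPLETION is services-API only here.
        System.out.println(servicesApi.containsAll(inferenceAction));          // true
        System.out.println(inferenceAction.contains(TaskType.CHAT_COMPLETION)); // false
    }
}
```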