Skip to content

Commit 186fb06

Browse files
committed
Fix checking of supported models
1 parent 76052b2 commit 186fb06

File tree

4 files changed

+36
-28
lines changed

4 files changed

+36
-28
lines changed

tests/external_aiobotocore/test_bedrock_chat_completion_invoke_model.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515
import os
1616
from io import BytesIO
17+
from pprint import pformat
1718

1819
import botocore.eventstream
1920
import botocore.exceptions
@@ -856,7 +857,12 @@ def test_bedrock_chat_completion_functions_marked_as_wrapped_for_sdk_compatibili
856857
def test_chat_models_instrumented(loop):
857858
import aiobotocore
858859

859-
SUPPORTED_MODELS = [model for model, _, _, _ in MODEL_EXTRACTORS if "embed" not in model]
860+
def _is_supported_model(model):
861+
supported_models = [model for model, _, _, _ in MODEL_EXTRACTORS if "embed" not in model]
862+
for supported_model in supported_models:
863+
if supported_model in model:
864+
return True
865+
return False
860866

861867
_id = os.environ.get("AWS_ACCESS_KEY_ID")
862868
key = os.environ.get("AWS_SECRET_ACCESS_KEY")
@@ -869,12 +875,8 @@ def test_chat_models_instrumented(loop):
869875
try:
870876
response = loop.run_until_complete(client.list_foundation_models(byOutputModality="TEXT"))
871877
models = [model["modelId"] for model in response["modelSummaries"]]
872-
not_supported = []
873-
for model in models:
874-
is_supported = any(model.startswith(supported_model) for supported_model in SUPPORTED_MODELS)
875-
if not is_supported:
876-
not_supported.append(model)
878+
not_supported = [model for model in models if not _is_supported_model(model)]
877879

878-
assert not not_supported, f"The following unsupported models were found: {not_supported}"
880+
assert not not_supported, f"The following unsupported models were found: {pformat(not_supported)}"
879881
finally:
880882
loop.run_until_complete(client.__aexit__(None, None, None))

tests/external_aiobotocore/test_bedrock_embeddings.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515
import os
1616
from io import BytesIO
17+
from pprint import pformat
1718

1819
import botocore.exceptions
1920
import pytest
@@ -414,7 +415,12 @@ async def _test():
414415
def test_embedding_models_instrumented(loop):
415416
import aiobotocore
416417

417-
SUPPORTED_MODELS = [model for model, _, _, _ in MODEL_EXTRACTORS if "embed" in model]
418+
def _is_supported_model(model):
419+
supported_models = [model for model, _, _, _ in MODEL_EXTRACTORS if "embed" in model]
420+
for supported_model in supported_models:
421+
if supported_model in model:
422+
return True
423+
return False
418424

419425
_id = os.environ.get("AWS_ACCESS_KEY_ID")
420426
key = os.environ.get("AWS_SECRET_ACCESS_KEY")
@@ -427,12 +433,8 @@ def test_embedding_models_instrumented(loop):
427433
try:
428434
response = client.list_foundation_models(byOutputModality="EMBEDDING")
429435
models = [model["modelId"] for model in response["modelSummaries"]]
430-
not_supported = []
431-
for model in models:
432-
is_supported = any(model.startswith(supported_model) for supported_model in SUPPORTED_MODELS)
433-
if not is_supported:
434-
not_supported.append(model)
436+
not_supported = [model for model in models if not _is_supported_model(model)]
435437

436-
assert not not_supported, f"The following unsupported models were found: {not_supported}"
438+
assert not not_supported, f"The following unsupported models were found: {pformat(not_supported)}"
437439
finally:
438440
loop.run_until_complete(client.__aexit__(None, None, None))

tests/external_botocore/test_bedrock_chat_completion_invoke_model.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515
import os
1616
from io import BytesIO
17+
from pprint import pformat
1718

1819
import boto3
1920
import botocore.eventstream
@@ -815,7 +816,12 @@ def test_bedrock_chat_completion_functions_marked_as_wrapped_for_sdk_compatibili
815816

816817

817818
def test_chat_models_instrumented():
818-
SUPPORTED_MODELS = [model for model, _, _, _ in MODEL_EXTRACTORS if "embed" not in model]
819+
def _is_supported_model(model):
820+
supported_models = [model for model, _, _, _ in MODEL_EXTRACTORS if "embed" not in model]
821+
for supported_model in supported_models:
822+
if supported_model in model:
823+
return True
824+
return False
819825

820826
_id = os.environ.get("AWS_ACCESS_KEY_ID")
821827
key = os.environ.get("AWS_SECRET_ACCESS_KEY")
@@ -825,10 +831,6 @@ def test_chat_models_instrumented():
825831
client = boto3.client("bedrock", "us-east-1")
826832
response = client.list_foundation_models(byOutputModality="TEXT")
827833
models = [model["modelId"] for model in response["modelSummaries"]]
828-
not_supported = []
829-
for model in models:
830-
is_supported = any(model.startswith(supported_model) for supported_model in SUPPORTED_MODELS)
831-
if not is_supported:
832-
not_supported.append(model)
834+
not_supported = [model for model in models if not _is_supported_model(model)]
833835

834-
assert not not_supported, f"The following unsupported models were found: {not_supported}"
836+
assert not not_supported, f"The following unsupported models were found: {pformat(not_supported)}"

tests/external_botocore/test_bedrock_embeddings.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515
import os
1616
from io import BytesIO
17+
from pprint import pformat
1718

1819
import boto3
1920
import botocore.exceptions
@@ -409,7 +410,12 @@ def _test():
409410

410411

411412
def test_embedding_models_instrumented():
412-
SUPPORTED_MODELS = [model for model, _, _, _ in MODEL_EXTRACTORS if "embed" in model]
413+
def _is_supported_model(model):
414+
supported_models = [model for model, _, _, _ in MODEL_EXTRACTORS if "embed" in model]
415+
for supported_model in supported_models:
416+
if supported_model in model:
417+
return True
418+
return False
413419

414420
_id = os.environ.get("AWS_ACCESS_KEY_ID")
415421
key = os.environ.get("AWS_SECRET_ACCESS_KEY")
@@ -419,10 +425,6 @@ def test_embedding_models_instrumented():
419425
client = boto3.client("bedrock", "us-east-1")
420426
response = client.list_foundation_models(byOutputModality="EMBEDDING")
421427
models = [model["modelId"] for model in response["modelSummaries"]]
422-
not_supported = []
423-
for model in models:
424-
is_supported = any(model.startswith(supported_model) for supported_model in SUPPORTED_MODELS)
425-
if not is_supported:
426-
not_supported.append(model)
428+
not_supported = [model for model in models if not _is_supported_model(model)]
427429

428-
assert not not_supported, f"The following unsupported models were found: {not_supported}"
430+
assert not not_supported, f"The following unsupported models were found: {pformat(not_supported)}"

0 commit comments

Comments
 (0)