Skip to content

Commit dafa811

Browse files
authored
Harden Azure ML url validation (#88)
This PR hardens the URL validation for the AzureML LLM endpoint. It also migrates the validator method to the Pydantic V2 style.
1 parent 7de6c0d commit dafa811

File tree

2 files changed

+30
-19
lines changed

2 files changed

+30
-19
lines changed

libs/community/langchain_community/llms/azureml_endpoint.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@
44
from abc import abstractmethod
55
from enum import Enum
66
from typing import Any, Dict, List, Mapping, Optional
7+
from urllib.parse import urlparse
78

89
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
910
from langchain_core.language_models.llms import BaseLLM
1011
from langchain_core.outputs import Generation, LLMResult
1112
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
12-
from pydantic import BaseModel, ConfigDict, SecretStr, model_validator, validator
13+
from pydantic import (
14+
BaseModel,
15+
ConfigDict,
16+
SecretStr,
17+
field_validator,
18+
model_validator,
19+
validator,
20+
)
1321

1422
DEFAULT_TIMEOUT = 50
1523

@@ -431,43 +439,42 @@ def validate_content_formatter(
431439
)
432440
return field_value
433441

434-
@validator("endpoint_url")
435-
def validate_endpoint_url(cls, field_value: Any) -> str:
442+
@field_validator("endpoint_url", mode="after")
443+
@classmethod
444+
def validate_endpoint_url(cls, value: str) -> str:
436445
"""Validate that endpoint url is complete."""
437-
if field_value.endswith("/"):
438-
field_value = field_value[:-1]
439-
if field_value.endswith("inference.ml.azure.com"):
446+
if value.endswith("/"): # trim trailing slash
447+
value = value[:-1]
448+
url = urlparse(value)
449+
if not url.path or url.path == "/":
440450
raise ValueError(
441451
"`endpoint_url` should contain the full invocation URL including "
442452
"`/score` for `endpoint_api_type='dedicated'` or `/completions` "
443453
"or `/models/chat/completions` "
444454
"for `endpoint_api_type='serverless'`"
445455
)
446-
return field_value
456+
return value
447457

448458
@validator("endpoint_api_type")
449459
def validate_endpoint_api_type(
450460
cls, field_value: Any, values: Dict
451461
) -> AzureMLEndpointApiType:
452462
"""Validate that endpoint api type is compatible with the URL format."""
453-
endpoint_url = values.get("endpoint_url")
463+
endpoint_url = urlparse(values.get("endpoint_url"))
454464
if (
455-
(
456-
field_value == AzureMLEndpointApiType.dedicated
457-
or field_value == AzureMLEndpointApiType.realtime
458-
)
459-
and not endpoint_url.endswith("/score") # type: ignore[union-attr]
460-
):
465+
field_value == AzureMLEndpointApiType.dedicated
466+
or field_value == AzureMLEndpointApiType.realtime
467+
) and not endpoint_url.path == "/score":
461468
raise ValueError(
462469
"Endpoints of type `dedicated` should follow the format "
463470
"`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
464471
" If your endpoint URL ends with `/completions` or"
465472
"`/models/chat/completions`,"
466473
"use `endpoint_api_type='serverless'` instead."
467474
)
468-
if field_value == AzureMLEndpointApiType.serverless and not (
469-
endpoint_url.endswith("/completions") # type: ignore[union-attr]
470-
or endpoint_url.endswith("/models/chat/completions") # type: ignore[union-attr]
475+
if (
476+
field_value == AzureMLEndpointApiType.serverless
477+
and endpoint_url.path not in ["/completions", "/models/chat/completions"]
471478
):
472479
raise ValueError(
473480
"Endpoints of type `serverless` should follow the format "

libs/community/tests/integration_tests/llms/test_azureml_endpoint.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,16 @@ def format_response_payload(self, output: bytes) -> str: # type: ignore[overrid
127127
llm.invoke("Foo")
128128

129129

130-
def test_incorrect_url() -> None:
130+
@pytest.mark.parametrize(
131+
"endpoint_url",
132+
["https://endpoint.inference.com", "https://endpoint.inference.com/"],
133+
)
134+
def test_incorrect_url(endpoint_url: str) -> None:
131135
"""Testing AzureML Endpoint for an incorrect URL"""
132136
with pytest.raises(ValidationError):
133137
llm = AzureMLOnlineEndpoint(
134138
endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"), # type: ignore[arg-type]
135-
endpoint_url="https://endpoint.inference.com",
139+
endpoint_url=endpoint_url,
136140
deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"), # type: ignore[arg-type]
137141
content_formatter=OSSContentFormatter(),
138142
)

0 commit comments

Comments
 (0)