|
4 | 4 | from abc import abstractmethod
|
5 | 5 | from enum import Enum
|
6 | 6 | from typing import Any, Dict, List, Mapping, Optional
|
| 7 | +from urllib.parse import urlparse |
7 | 8 |
|
8 | 9 | from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
9 | 10 | from langchain_core.language_models.llms import BaseLLM
|
10 | 11 | from langchain_core.outputs import Generation, LLMResult
|
11 | 12 | from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
12 |
| -from pydantic import BaseModel, ConfigDict, SecretStr, model_validator, validator |
| 13 | +from pydantic import ( |
| 14 | + BaseModel, |
| 15 | + ConfigDict, |
| 16 | + SecretStr, |
| 17 | + field_validator, |
| 18 | + model_validator, |
| 19 | + validator, |
| 20 | +) |
13 | 21 |
|
14 | 22 | DEFAULT_TIMEOUT = 50
|
15 | 23 |
|
@@ -431,43 +439,42 @@ def validate_content_formatter(
|
431 | 439 | )
|
432 | 440 | return field_value
|
433 | 441 |
|
434 |
| - @validator("endpoint_url") |
435 |
| - def validate_endpoint_url(cls, field_value: Any) -> str: |
| 442 | + @field_validator("endpoint_url", mode="after") |
| 443 | + @classmethod |
| 444 | + def validate_endpoint_url(cls, value: str) -> str: |
436 | 445 | """Validate that endpoint url is complete."""
|
437 |
| - if field_value.endswith("/"): |
438 |
| - field_value = field_value[:-1] |
439 |
| - if field_value.endswith("inference.ml.azure.com"): |
| 446 | + if value.endswith("/"): # trim trailing slash |
| 447 | + value = value[:-1] |
| 448 | + url = urlparse(value) |
| 449 | + if not url.path or url.path == "/": |
440 | 450 | raise ValueError(
|
441 | 451 | "`endpoint_url` should contain the full invocation URL including "
|
442 | 452 | "`/score` for `endpoint_api_type='dedicated'` or `/completions` "
|
443 | 453 | "or `/models/chat/completions` "
|
444 | 454 | "for `endpoint_api_type='serverless'`"
|
445 | 455 | )
|
446 |
| - return field_value |
| 456 | + return value |
447 | 457 |
|
448 | 458 | @validator("endpoint_api_type")
|
449 | 459 | def validate_endpoint_api_type(
|
450 | 460 | cls, field_value: Any, values: Dict
|
451 | 461 | ) -> AzureMLEndpointApiType:
|
452 | 462 | """Validate that endpoint api type is compatible with the URL format."""
|
453 |
| - endpoint_url = values.get("endpoint_url") |
| 463 | + endpoint_url = urlparse(values.get("endpoint_url")) |
454 | 464 | if (
|
455 |
| - ( |
456 |
| - field_value == AzureMLEndpointApiType.dedicated |
457 |
| - or field_value == AzureMLEndpointApiType.realtime |
458 |
| - ) |
459 |
| - and not endpoint_url.endswith("/score") # type: ignore[union-attr] |
460 |
| - ): |
| 465 | + field_value == AzureMLEndpointApiType.dedicated |
| 466 | + or field_value == AzureMLEndpointApiType.realtime |
| 467 | + ) and not endpoint_url.path == "/score": |
461 | 468 | raise ValueError(
|
462 | 469 | "Endpoints of type `dedicated` should follow the format "
|
463 | 470 | "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
|
464 | 471 | " If your endpoint URL ends with `/completions` or"
|
465 | 472 | "`/models/chat/completions`,"
|
466 | 473 | "use `endpoint_api_type='serverless'` instead."
|
467 | 474 | )
|
468 |
| - if field_value == AzureMLEndpointApiType.serverless and not ( |
469 |
| - endpoint_url.endswith("/completions") # type: ignore[union-attr] |
470 |
| - or endpoint_url.endswith("/models/chat/completions") # type: ignore[union-attr] |
| 475 | + if ( |
| 476 | + field_value == AzureMLEndpointApiType.serverless |
| 477 | + and endpoint_url.path not in ["/completions", "/models/chat/completions"] |
471 | 478 | ):
|
472 | 479 | raise ValueError(
|
473 | 480 | "Endpoints of type `serverless` should follow the format "
|
|
0 commit comments