|
4 | 4 | import json
|
5 | 5 | import logging
|
6 | 6 | from copy import deepcopy
|
7 |
| -from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, get_origin |
| 7 | +from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast, get_origin |
8 | 8 | from urllib.parse import parse_qs
|
9 | 9 |
|
10 | 10 | from pydantic import BaseModel
|
@@ -456,32 +456,59 @@ def _normalize_multi_params(
|
456 | 456 | """
|
457 | 457 | for param in params:
|
458 | 458 | if is_scalar_field(param):
|
459 |
| - try: |
460 |
| - val = input_dict[param.alias] |
461 |
| - if isinstance(val, list) and len(val) == 1: |
462 |
| - input_dict[param.alias] = val[0] |
463 |
| - elif isinstance(val, list): |
464 |
| - pass # leave as list for multi-value |
465 |
| - # If it's a string, leave as is |
466 |
| - except KeyError: |
467 |
| - pass |
| 459 | + _process_scalar_param(input_dict, param) |
468 | 460 | elif lenient_issubclass(param.field_info.annotation, BaseModel):
|
469 |
| - model_class = param.field_info.annotation |
470 |
| - model_data = {} |
471 |
| - |
472 |
| - for field_name, field_def in model_class.model_fields.items(): |
473 |
| - field_alias = field_def.alias or field_name |
474 |
| - value = input_dict.get(field_alias) |
475 |
| - if value is None and ( |
476 |
| - model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name") |
477 |
| - ): |
478 |
| - value = input_dict.get(field_name) |
479 |
| - if value is not None: |
480 |
| - if get_origin(field_def.annotation) is list: |
481 |
| - model_data[field_alias] = value |
482 |
| - elif isinstance(value, list): |
483 |
| - model_data[field_alias] = value[0] |
484 |
| - else: |
485 |
| - model_data[field_alias] = value |
486 |
| - input_dict[param.alias] = model_data |
| 461 | + _process_model_param(input_dict, param) |
487 | 462 | return input_dict
|
| 463 | + |
| 464 | + |
| 465 | +def _process_scalar_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None: |
| 466 | + """Process a scalar parameter by normalizing single-item lists.""" |
| 467 | + try: |
| 468 | + val = input_dict[param.alias] |
| 469 | + if isinstance(val, list) and len(val) == 1: |
| 470 | + input_dict[param.alias] = val[0] |
| 471 | + except KeyError: |
| 472 | + pass |
| 473 | + |
| 474 | + |
| 475 | +def _process_model_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None: |
| 476 | + """Process a Pydantic model parameter by extracting model fields.""" |
| 477 | + model_class = cast(type[BaseModel], param.field_info.annotation) |
| 478 | + |
| 479 | + model_data = {} |
| 480 | + for field_name, field_def in model_class.model_fields.items(): |
| 481 | + field_alias = field_def.alias or field_name |
| 482 | + value = _get_param_value(input_dict, field_alias, field_name, model_class) |
| 483 | + |
| 484 | + if value is not None: |
| 485 | + model_data[field_alias] = _normalize_field_value(value, field_def) |
| 486 | + |
| 487 | + input_dict[param.alias] = model_data |
| 488 | + |
| 489 | + |
| 490 | +def _get_param_value( |
| 491 | + input_dict: MutableMapping[str, Any], |
| 492 | + field_alias: str, |
| 493 | + field_name: str, |
| 494 | + model_class: type[BaseModel], |
| 495 | +) -> Any: |
| 496 | + """Get parameter value, checking both alias and field name if needed.""" |
| 497 | + value = input_dict.get(field_alias) |
| 498 | + if value is not None: |
| 499 | + return value |
| 500 | + |
| 501 | + if model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name"): |
| 502 | + value = input_dict.get(field_name) |
| 503 | + |
| 504 | + return value |
| 505 | + |
| 506 | + |
| 507 | +def _normalize_field_value(value: Any, field_def: Any) -> Any: |
| 508 | + """Normalize field value based on its type annotation.""" |
| 509 | + if get_origin(field_def.annotation) is list: |
| 510 | + return value |
| 511 | + elif isinstance(value, list) and value: |
| 512 | + return value[0] |
| 513 | + else: |
| 514 | + return value |
0 commit comments