Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions pydantic_settings/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations as _annotations

import json
import os
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast, get_args
Expand Down Expand Up @@ -196,11 +196,17 @@ class ConfigFileSourceMixin(ABC):
def _read_files(self, files: PathType | None, deep_merge: bool = False) -> dict[str, Any]:
if files is None:
return {}
if isinstance(files, (str, os.PathLike)):
if not isinstance(files, Sequence) or isinstance(files, str):
files = [files]
vars: dict[str, Any] = {}
for file in files:
file_path = Path(file).expanduser()
if isinstance(file, str):
file_path = Path(file)
else:
file_path = file
if isinstance(file_path, Path):
file_path = file_path.expanduser()

if not file_path.is_file():
continue

Expand Down
2 changes: 1 addition & 1 deletion pydantic_settings/sources/providers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
super().__init__(settings_cls, self.json_data)

def _read_file(self, file_path: Path) -> dict[str, Any]:
with open(file_path, encoding=self.json_file_encoding) as json_file:
with file_path.open(encoding=self.json_file_encoding) as json_file:
return json.load(json_file)

def __repr__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_settings/sources/providers/toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(

def _read_file(self, file_path: Path) -> dict[str, Any]:
import_toml()
with open(file_path, mode='rb') as toml_file:
with file_path.open(mode='rb') as toml_file:
if sys.version_info < (3, 11):
return tomli.load(toml_file)
return tomllib.load(toml_file)
Expand Down
2 changes: 1 addition & 1 deletion pydantic_settings/sources/providers/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(

def _read_file(self, file_path: Path) -> dict[str, Any]:
import_yaml()
with open(file_path, encoding=self.yaml_file_encoding) as yaml_file:
with file_path.open(encoding=self.yaml_file_encoding) as yaml_file:
return yaml.safe_load(yaml_file) or {}

def __repr__(self) -> str:
Expand Down
82 changes: 82 additions & 0 deletions tests/test_source_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
Test pydantic_settings.JsonConfigSettingsSource.
"""

import importlib.resources
import json
import sys

if sys.version_info < (3, 11):
from importlib.abc import Traversable
else:
from importlib.resources.abc import Traversable

from pathlib import Path

import pytest
Expand Down Expand Up @@ -132,3 +140,77 @@ def settings_customise_sources(

s = Settings()
assert s.model_dump() == {'hello': 'world', 'nested': {'foo': 3, 'bar': 2 if deep_merge else 0}}


class TestTraversableSupport:
FILENAME = 'example_test_config.json'

@pytest.fixture(params=['importlib_resources', 'custom'])
def json_config_path(self, request, tmp_path):
tests_package_dir = importlib.resources.files('tests')

if request.param == 'importlib_resources':
# get Traversable object using importlib.resources
return tests_package_dir / self.FILENAME

# Create a custom Traversable implementation
class CustomTraversable(Traversable):
def __init__(self, path):
self._path = path

def __truediv__(self, child):
return CustomTraversable(self._path / child)

def is_file(self):
return self._path.is_file()

def is_dir(self):
return self._path.is_dir()

def iterdir(self):
raise NotImplementedError('iterdir not implemented for this test')

def open(self, mode='r', *args, **kwargs):
return self._path.open(mode, *args, **kwargs)

def read_bytes(self):
return self._path.read_bytes()

def read_text(self, encoding=None):
return self._path.read_text(encoding=encoding)

@property
def name(self):
return self._path.name

def joinpath(self, *descendants):
return CustomTraversable(self._path.joinpath(*descendants))

custom_traversable = CustomTraversable(tests_package_dir)
return custom_traversable / self.FILENAME

def test_traversable_support(self, json_config_path: Traversable):
assert json_config_path.is_file()

class Settings(BaseSettings):
foobar: str

model_config = SettingsConfigDict(
# Traversable is not added in annotation, but is supported
json_file=json_config_path,
)

@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (JsonConfigSettingsSource(settings_cls),)

s = Settings()
# "test" value in file
assert s.foobar == 'test'