Skip to content

Commit 3219420

Browse files
authored
Add non-Path files support (for example Traversable) and open files using Path.open method (#724)
1 parent c158510 commit 3219420

File tree

5 files changed

+94
-6
lines changed

5 files changed

+94
-6
lines changed

pydantic_settings/sources/base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from __future__ import annotations as _annotations
44

55
import json
6-
import os
76
from abc import ABC, abstractmethod
7+
from collections.abc import Sequence
88
from dataclasses import asdict, is_dataclass
99
from pathlib import Path
1010
from typing import TYPE_CHECKING, Any, cast, get_args
@@ -196,11 +196,17 @@ class ConfigFileSourceMixin(ABC):
196196
def _read_files(self, files: PathType | None, deep_merge: bool = False) -> dict[str, Any]:
197197
if files is None:
198198
return {}
199-
if isinstance(files, (str, os.PathLike)):
199+
if not isinstance(files, Sequence) or isinstance(files, str):
200200
files = [files]
201201
vars: dict[str, Any] = {}
202202
for file in files:
203-
file_path = Path(file).expanduser()
203+
if isinstance(file, str):
204+
file_path = Path(file)
205+
else:
206+
file_path = file
207+
if isinstance(file_path, Path):
208+
file_path = file_path.expanduser()
209+
204210
if not file_path.is_file():
205211
continue
206212

pydantic_settings/sources/providers/json.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
super().__init__(settings_cls, self.json_data)
3939

4040
def _read_file(self, file_path: Path) -> dict[str, Any]:
41-
with open(file_path, encoding=self.json_file_encoding) as json_file:
41+
with file_path.open(encoding=self.json_file_encoding) as json_file:
4242
return json.load(json_file)
4343

4444
def __repr__(self) -> str:

pydantic_settings/sources/providers/toml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858

5959
def _read_file(self, file_path: Path) -> dict[str, Any]:
6060
import_toml()
61-
with open(file_path, mode='rb') as toml_file:
61+
with file_path.open(mode='rb') as toml_file:
6262
if sys.version_info < (3, 11):
6363
return tomli.load(toml_file)
6464
return tomllib.load(toml_file)

pydantic_settings/sources/providers/yaml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666

6767
def _read_file(self, file_path: Path) -> dict[str, Any]:
6868
import_yaml()
69-
with open(file_path, encoding=self.yaml_file_encoding) as yaml_file:
69+
with file_path.open(encoding=self.yaml_file_encoding) as yaml_file:
7070
return yaml.safe_load(yaml_file) or {}
7171

7272
def __repr__(self) -> str:

tests/test_source_json.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22
Test pydantic_settings.JsonConfigSettingsSource.
33
"""
44

5+
import importlib.resources
56
import json
7+
import sys
8+
9+
if sys.version_info < (3, 11):
10+
from importlib.abc import Traversable
11+
else:
12+
from importlib.resources.abc import Traversable
13+
614
from pathlib import Path
715

816
import pytest
@@ -132,3 +140,77 @@ def settings_customise_sources(
132140

133141
s = Settings()
134142
assert s.model_dump() == {'hello': 'world', 'nested': {'foo': 3, 'bar': 2 if deep_merge else 0}}
143+
144+
145+
class TestTraversableSupport:
146+
FILENAME = 'example_test_config.json'
147+
148+
@pytest.fixture(params=['importlib_resources', 'custom'])
149+
def json_config_path(self, request, tmp_path):
150+
tests_package_dir = importlib.resources.files('tests')
151+
152+
if request.param == 'importlib_resources':
153+
# get Traversable object using importlib.resources
154+
return tests_package_dir / self.FILENAME
155+
156+
# Create a custom Traversable implementation
157+
class CustomTraversable(Traversable):
158+
def __init__(self, path):
159+
self._path = path
160+
161+
def __truediv__(self, child):
162+
return CustomTraversable(self._path / child)
163+
164+
def is_file(self):
165+
return self._path.is_file()
166+
167+
def is_dir(self):
168+
return self._path.is_dir()
169+
170+
def iterdir(self):
171+
raise NotImplementedError('iterdir not implemented for this test')
172+
173+
def open(self, mode='r', *args, **kwargs):
174+
return self._path.open(mode, *args, **kwargs)
175+
176+
def read_bytes(self):
177+
return self._path.read_bytes()
178+
179+
def read_text(self, encoding=None):
180+
return self._path.read_text(encoding=encoding)
181+
182+
@property
183+
def name(self):
184+
return self._path.name
185+
186+
def joinpath(self, *descendants):
187+
return CustomTraversable(self._path.joinpath(*descendants))
188+
189+
custom_traversable = CustomTraversable(tests_package_dir)
190+
return custom_traversable / self.FILENAME
191+
192+
def test_traversable_support(self, json_config_path: Traversable):
193+
assert json_config_path.is_file()
194+
195+
class Settings(BaseSettings):
196+
foobar: str
197+
198+
model_config = SettingsConfigDict(
199+
# Traversable is not added in annotation, but is supported
200+
json_file=json_config_path,
201+
)
202+
203+
@classmethod
204+
def settings_customise_sources(
205+
cls,
206+
settings_cls: type[BaseSettings],
207+
init_settings: PydanticBaseSettingsSource,
208+
env_settings: PydanticBaseSettingsSource,
209+
dotenv_settings: PydanticBaseSettingsSource,
210+
file_secret_settings: PydanticBaseSettingsSource,
211+
) -> tuple[PydanticBaseSettingsSource, ...]:
212+
return (JsonConfigSettingsSource(settings_cls),)
213+
214+
s = Settings()
215+
# "test" value in file
216+
assert s.foobar == 'test'

0 commit comments

Comments
 (0)