Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions pgserviceparser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import configparser
import io
import platform
import stat
from os import getenv
from pathlib import Path
from typing import Optional
Expand All @@ -10,6 +11,51 @@
from .exceptions import ServiceFileNotFound, ServiceNotFound


def _make_file_writable(path: Path):
"""Attempt to add write permissions to a file.

Args:
path: path to the file

Raises:
PermissionError: when the file permissions cannot be changed
"""
current_permission = stat.S_IMODE(path.stat().st_mode)
WRITE = stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH
path.chmod(current_permission | WRITE)


def _when_read_only_try_to_add_write_permission(func):
"""Decorator for functions that attempt to modify the service file.

If the file is read-only, a PermissionError exception will be raised.
This decorator handles that error by attempting to set write permissions
(which works if the user is the owner of the file or has proper rights to
alter the file permissions), and rerunning the decorated function.

If the user cannot modify permissions on the file, the PermissionError
is re-raised.
"""

def wrapper(*args, **kwargs):
attempt = 0
while attempt <= 1:
try:
return func(*args, **kwargs)
except PermissionError:
if attempt == 1:
raise

try:
_make_file_writable(conf_path())
except PermissionError:
pass
finally:
attempt += 1

return wrapper


def conf_path(create_if_missing: Optional[bool] = False) -> Path:
"""Returns the path found for the pg_service.conf on the system as string.

Expand Down Expand Up @@ -66,6 +112,7 @@ def full_config(conf_file_path: Optional[Path] = None) -> configparser.ConfigPar
return config


@_when_read_only_try_to_add_write_permission
def remove_service(service_name: str, conf_file_path: Optional[Path] = None) -> None:
"""Remove a complete service from the service file.

Expand All @@ -92,6 +139,7 @@ def remove_service(service_name: str, conf_file_path: Optional[Path] = None) ->
config.write(configfile, space_around_delimiters=False)


@_when_read_only_try_to_add_write_permission
def rename_service(old_name: str, new_name: str, conf_file_path: Optional[Path] = None) -> None:
"""Rename a service in the service file.

Expand Down Expand Up @@ -124,6 +172,7 @@ def rename_service(old_name: str, new_name: str, conf_file_path: Optional[Path]
config.write(configfile, space_around_delimiters=False)


@_when_read_only_try_to_add_write_permission
def create_service(service_name: str, settings: dict, conf_file_path: Optional[Path] = None) -> bool:
"""Create a new service in the service file.

Expand Down Expand Up @@ -153,6 +202,7 @@ def create_service(service_name: str, settings: dict, conf_file_path: Optional[P
return True


@_when_read_only_try_to_add_write_permission
def copy_service_settings(
source_service_name: str,
target_service_name: str,
Expand Down Expand Up @@ -217,6 +267,7 @@ def service_config(service_name: str, conf_file_path: Optional[Path] = None) ->
return dict(config[service_name])


@_when_read_only_try_to_add_write_permission
def write_service_setting(
service_name: str,
setting_key: str,
Expand Down Expand Up @@ -250,6 +301,7 @@ def write_service_setting(
config.write(configfile, space_around_delimiters=False)


@_when_read_only_try_to_add_write_permission
def write_service(
service_name: str, settings: dict, conf_file_path: Optional[Path] = None, create_if_not_found: bool = False
) -> dict:
Expand Down Expand Up @@ -309,11 +361,13 @@ def write_service_to_text(service_name: str, settings: dict) -> str:
return res.strip()


def service_names(conf_file_path: Optional[Path] = None) -> list[str]:
def service_names(conf_file_path: Optional[Path] = None, sorted_alphabetically: bool = False) -> list[str]:
"""Returns all service names in a list.

Args:
conf_file_path: path to the pg_service.conf. If None the `conf_path()` is used, defaults to None
sorted_alphabetically: whether to sort the names alphabetically (case-insensitive),
defaults to False

Returns:
list of every service registered
Expand All @@ -323,4 +377,5 @@ def service_names(conf_file_path: Optional[Path] = None) -> list[str]:
"""

config = full_config(conf_file_path)
return config.sections()
names = config.sections()
return sorted(names, key=str.lower) if sorted_alphabetically else names
29 changes: 29 additions & 0 deletions test/test_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import os
import shutil
import stat
import unittest
from pathlib import Path

Expand Down Expand Up @@ -57,6 +58,22 @@ def test_full_config(self):
def test_service_names(self):
self.assertEqual(service_names(), ["service_1", "service_2", "service_3", "service_4"])

def test_service_names_sorted_alphabetically(self):
# Add a service whose name comes before existing ones alphabetically
create_service("Alpha_service", {"host": "localhost"})
create_service("zulu_service", {"host": "localhost"})

# Without sorting, order is as written in the file (appended at the end)
names = service_names()
self.assertEqual(names[-2:], ["Alpha_service", "zulu_service"])

# With sorting, names are case-insensitive alphabetical
sorted_names = service_names(sorted_alphabetically=True)
self.assertEqual(
sorted_names,
["Alpha_service", "service_1", "service_2", "service_3", "service_4", "zulu_service"],
)

def test_service_config(self):
self.assertRaises(ServiceNotFound, service_config, "non_existing_service")

Expand Down Expand Up @@ -228,6 +245,18 @@ def test_remove_service(self):
self.assertIn("service_tmp", service_names())
remove_service("service_tmp")

def test_write_on_read_only_file(self):
# Make the service file read-only
self.service_file_path.chmod(stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)

# The decorator should automatically add write permission and succeed
write_service_setting("service_1", "port", "9999")
conf = service_config("service_1")
self.assertEqual(conf["port"], "9999")

# Verify the file is writable again
self.assertTrue(os.access(self.service_file_path, os.W_OK))

def test_missing_file(self):
another_service_file_path = PGSERVICEPARSER_SRC_PATH / "test" / "data" / "new_folder" / "pgservice.conf"
os.environ["PGSERVICEFILE"] = str(another_service_file_path)
Expand Down