Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/changelog/114.miscellaneous.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Feat: add fullpath to set UDS sock filename
68 changes: 48 additions & 20 deletions src/ansys/tools/common/cyberchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def create_channel(
uds_service: str | None = None,
uds_dir: str | Path | None = None,
uds_id: str | None = None,
uds_fullpath: str | Path | None = None,
certs_dir: str | Path | None = None,
cert_files: CertificateFiles | None = None,
grpc_options: list[tuple[str, object]] | None = None,
Expand Down Expand Up @@ -108,6 +109,9 @@ def create_channel(
Optional ID to use for the UDS socket filename.
By default `None` and thus it will use "<uds_service>.sock".
Otherwise, the socket filename will be "<uds_service>-<uds_id>.sock".
uds_fullpath : str | Path | None
Full path to the UDS socket file.
By default `None` and thus it will use the `uds_service`, `uds_dir` and `uds_id` parameters.
certs_dir : str | Path | None
Directory to use for TLS certificates.
By default `None` and thus search for the "ANSYS_GRPC_CERTIFICATES" environment variable.
Expand Down Expand Up @@ -141,7 +145,7 @@ def check_host_port(transport_mode, host, port) -> tuple[str, str, str]:
transport_mode, host, port = check_host_port(transport_mode, host, port)
return create_insecure_channel(host, port, grpc_options)
case "uds":
return create_uds_channel(uds_service, uds_dir, uds_id, grpc_options)
return create_uds_channel(uds_service, uds_dir, uds_id, grpc_options, uds_fullpath)
case "wnua":
transport_mode, host, port = check_host_port(transport_mode, host, port)
return create_wnua_channel(host, port, grpc_options)
Expand Down Expand Up @@ -186,16 +190,17 @@ def create_insecure_channel(


def create_uds_channel(
uds_service: str | None,
uds_service: str | None = None,
uds_dir: str | Path | None = None,
uds_id: str | None = None,
grpc_options: list[tuple[str, object]] | None = None,
uds_fullpath: str | Path | None = None,
) -> grpc.Channel:
"""Create a gRPC channel using Unix Domain Sockets (UDS).

Parameters
----------
uds_service : str
uds_service : str | None
Service name for the UDS socket.
uds_dir : str | Path | None
Directory to use for Unix Domain Sockets (UDS) transport mode.
Expand All @@ -208,6 +213,9 @@ def create_uds_channel(
gRPC channel options to pass when creating the channel.
Each option is a tuple of the form ("option_name", value).
By default `None` and thus only the default authority option is added.
uds_fullpath : str | Path | None
Full path to the UDS socket file.
By default `None` and thus it will use the `uds_service`, `uds_dir` and `uds_id` parameters.

Returns
-------
Expand All @@ -218,18 +226,24 @@ def create_uds_channel(
if not is_uds_supported():
raise RuntimeError("Unix Domain Sockets are not supported on this platform or gRPC version.")

if not uds_service:
raise ValueError("When using UDS transport mode, 'uds_service' must be provided.")
if uds_fullpath:
# Ensure the parent directory exists
Path(uds_fullpath).parent.mkdir(parents=True, exist_ok=True)
target = f"unix:{uds_fullpath}"
else:
if uds_service is None:
raise ValueError("When using UDS transport mode, 'uds_service' must be provided.")

# Determine UDS folder
uds_folder = determine_uds_folder(uds_dir)
# Determine UDS folder
uds_folder = determine_uds_folder(uds_dir)

# Make sure the folder exists
uds_folder.mkdir(parents=True, exist_ok=True)
# Make sure the folder exists
uds_folder.mkdir(parents=True, exist_ok=True)

# Generate socket filename with optional ID
socket_filename = f"{uds_service}-{uds_id}.sock" if uds_id else f"{uds_service}.sock"
target = f"unix:{uds_folder / socket_filename}"

# Generate socket filename with optional ID
socket_filename = f"{uds_service}-{uds_id}.sock" if uds_id else f"{uds_service}.sock"
target = f"unix:{uds_folder / socket_filename}"
# Set default authority to "localhost" for UDS connection
# This is needed to avoid issues with some gRPC implementations,
# see https://github.com/grpc/grpc/issues/34305
Expand Down Expand Up @@ -476,12 +490,17 @@ def verify_transport_mode(transport_mode: str, mode: str | None = None) -> None:
raise ValueError(f"Invalid transport mode: {transport_mode}. Valid options are: {', '.join(valid_modes)}.")


def verify_uds_socket(uds_service: str, uds_dir: Path | None = None, uds_id: str | None = None) -> bool:
def verify_uds_socket(
uds_service: str | None = None,
uds_dir: Path | None = None,
uds_id: str | None = None,
uds_fullpath: str | Path | None = None,
) -> bool:
"""Verify that the UDS socket file has been created.

Parameters
----------
uds_service : str
uds_service : str | None
Service name for the UDS socket.
uds_dir : Path | None
Directory where the UDS socket file is expected to be (optional).
Expand All @@ -490,17 +509,26 @@ def verify_uds_socket(uds_service: str, uds_dir: Path | None = None, uds_id: str
Unique identifier for the UDS socket (optional).
By default `None` and thus it will use "<uds_service>.sock".
Otherwise, the socket filename will be "<uds_service>-<uds_id>.sock".
uds_fullpath : str | Path | None
Full path to the UDS socket file.
By default `None` and thus it will use the `uds_service`, `uds_dir` and `uds_id` parameters.

Returns
-------
bool
True if the UDS socket file exists, False otherwise.
"""
# Generate socket filename with optional ID
uds_filename = f"{uds_service}-{uds_id}.sock" if uds_id else f"{uds_service}.sock"
if uds_fullpath:
return Path(uds_fullpath).exists()
else:
if uds_service is None:
raise ValueError("When using UDS transport mode, 'uds_service' must be provided.")

# Generate socket filename with optional ID
uds_filename = f"{uds_service}-{uds_id}.sock" if uds_id else f"{uds_service}.sock"

# Full path to the UDS socket file
uds_socket_path = determine_uds_folder(uds_dir) / uds_filename
# Full path to the UDS socket file
uds_socket_path = determine_uds_folder(uds_dir) / uds_filename

# Check if the UDS socket file exists
return uds_socket_path.exists()
# Check if the UDS socket file exists
return uds_socket_path.exists()
84 changes: 84 additions & 0 deletions tests/test_cyberchannel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (C) 2025 ANSYS, Inc. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Tests for cyberchannel."""

import os
from pathlib import Path
import tempfile

import pytest

from ansys.tools.common import cyberchannel


def test_version_tuple():
"""Test version tuple."""
assert cyberchannel.version_tuple("1.2.3") == (1, 2, 3)
assert cyberchannel.version_tuple("1.2.3.4") == (1, 2, 3, 4)
assert cyberchannel.version_tuple("1.0.0") == (1, 0, 0)


def test_cyberchannel_functions():
"""Test cyberchannel functions."""
assert cyberchannel.check_grpc_version()
assert cyberchannel.is_uds_supported()
uds_path = cyberchannel.determine_uds_folder()
uds_path.mkdir(parents=True, exist_ok=True)
assert uds_path.is_dir()
assert uds_path.exists()
uds_path.rmdir()
cyberchannel.verify_transport_mode(transport_mode="insecure", mode="local")
with pytest.raises(ValueError):
cyberchannel.verify_transport_mode(transport_mode="invalid_mode", mode="mode1")


def test_cyberchannel_insecure():
"""Test cyberchannel insecure."""
ch = cyberchannel.create_insecure_channel(host="localhost", port=12345)
assert ch is not None
assert ch._channel.target().decode() == "dns:///localhost:12345"
assert not ch.close()


@pytest.mark.skipif(os.name != "nt", reason="WNUA is only supported on Windows.")
def test_cyberchannel_wnua():
"""Test cyberchannel wnua."""
ch = cyberchannel.create_wnua_channel(host="localhost", port=12345)
assert ch is not None
assert ch._channel.target().decode() == "dns:///localhost:12345"
assert not ch.close()


def test_cyberchannel_uds():
"""Test cyberchannel uds."""
uds_file = Path(tempfile.gettempdir()) / "test_uds.sock"
with uds_file.open("w"):
pass
ch = cyberchannel.create_uds_channel(uds_fullpath=uds_file)
assert ch is not None
assert ch._channel.target().decode() == f"unix:{uds_file}"
assert not ch.close()

ch = cyberchannel.create_uds_channel("service_name")
assert ch is not None
assert ch._channel.target().decode() == f"unix:{cyberchannel.determine_uds_folder() / 'service_name.sock'}"
assert not ch.close()