|
| 1 | +# Copyright (C) 2025 ANSYS, Inc. and/or its affiliates. |
| 2 | +# SPDX-License-Identifier: MIT |
| 3 | +# |
| 4 | +# |
| 5 | +# Permission is hereby granted, free of charge, to any person obtaining a copy |
| 6 | +# of this software and associated documentation files (the "Software"), to deal |
| 7 | +# in the Software without restriction, including without limitation the rights |
| 8 | +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 9 | +# copies of the Software, and to permit persons to whom the Software is |
| 10 | +# furnished to do so, subject to the following conditions: |
| 11 | +# |
| 12 | +# The above copyright notice and this permission notice shall be included in all |
| 13 | +# copies or substantial portions of the Software. |
| 14 | +# |
| 15 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 16 | +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 17 | +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 18 | +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 19 | +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 20 | +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 21 | +# SOFTWARE. |
| 22 | + |
| 23 | +"""Defines options for connecting to a gRPC server.""" |
| 24 | + |
| 25 | +from abc import ABC, abstractmethod |
| 26 | +from dataclasses import asdict, dataclass |
| 27 | +import enum |
| 28 | +from pathlib import Path |
| 29 | +from typing import TYPE_CHECKING, Any, ClassVar |
| 30 | + |
| 31 | +import grpc |
| 32 | + |
| 33 | +from .. import cyberchannel |
| 34 | + |
| 35 | +__all__ = [ |
| 36 | + "TransportMode", |
| 37 | + "UDSOptions", |
| 38 | + "WNUAOptions", |
| 39 | + "MTLSOptions", |
| 40 | + "InsecureOptions", |
| 41 | + "TransportOptionsType", |
| 42 | +] |
| 43 | + |
| 44 | +# For Python 3.10 and below, emulate the behavior of StrEnum by |
| 45 | +# inheriting from str and enum.Enum. |
| 46 | +# Note that this does *not* work on Python 3.11+, since the default |
| 47 | +# Enum format method has changed and will not return the value of |
| 48 | +# the enum member. |
| 49 | +# When type checking, always use the Python 3.10 workaround, otherwise |
| 50 | +# the StrEnum resolves as 'Any'. |
| 51 | +if TYPE_CHECKING: # pragma: no cover |
| 52 | + |
| 53 | + class StrEnum(str, enum.Enum): |
| 54 | + """String enum.""" |
| 55 | + |
| 56 | +else: |
| 57 | + try: |
| 58 | + from enum import StrEnum |
| 59 | + except ImportError: |
| 60 | + import enum |
| 61 | + |
| 62 | + class StrEnum(str, enum.Enum): |
| 63 | + """String enum.""" |
| 64 | + |
| 65 | + pass |
| 66 | + |
| 67 | + |
| 68 | +class TransportMode(StrEnum): |
| 69 | + """Enumeration of transport modes supported by the FileTransfer Tool.""" |
| 70 | + |
| 71 | + UDS = "uds" |
| 72 | + WNUA = "wnua" |
| 73 | + MTLS = "mtls" |
| 74 | + INSECURE = "insecure" |
| 75 | + |
| 76 | + |
| 77 | +class TransportOptionsBase(ABC): |
| 78 | + """Base class for transport options.""" |
| 79 | + |
| 80 | + _MODE: ClassVar[TransportMode] |
| 81 | + |
| 82 | + @property |
| 83 | + def mode(self) -> TransportMode: |
| 84 | + """Transport mode.""" |
| 85 | + return self._MODE |
| 86 | + |
| 87 | + def create_channel(self, **extra_kwargs: Any) -> grpc.Channel: |
| 88 | + """Create a gRPC channel using the transport options. |
| 89 | +
|
| 90 | + Parameters |
| 91 | + ---------- |
| 92 | + extra_kwargs : |
| 93 | + Extra keyword arguments to pass to the channel creation function. |
| 94 | +
|
| 95 | + Returns |
| 96 | + ------- |
| 97 | + : |
| 98 | + gRPC channel created using the transport options. |
| 99 | + """ |
| 100 | + return cyberchannel.create_channel(**self._to_cyberchannel_kwargs(), **extra_kwargs) |
| 101 | + |
| 102 | + @abstractmethod |
| 103 | + def _to_cyberchannel_kwargs(self) -> dict[str, Any]: |
| 104 | + """Convert transport options to cyberchannel keyword arguments. |
| 105 | +
|
| 106 | + Returns |
| 107 | + ------- |
| 108 | + : |
| 109 | + Dictionary of keyword arguments for cyberchannel. |
| 110 | + """ |
| 111 | + pass |
| 112 | + |
| 113 | + |
| 114 | +@dataclass(kw_only=True) |
| 115 | +class UDSOptions(TransportOptionsBase): |
| 116 | + """Options for UDS transport mode.""" |
| 117 | + |
| 118 | + _MODE = TransportMode.UDS |
| 119 | + |
| 120 | + uds_service: str |
| 121 | + uds_dir: str | Path | None = None |
| 122 | + uds_id: str | None = None |
| 123 | + |
| 124 | + def _to_cyberchannel_kwargs(self) -> dict[str, Any]: |
| 125 | + return asdict(self) | {"transport_mode": self.mode.value} |
| 126 | + |
| 127 | + |
| 128 | +@dataclass(kw_only=True) |
| 129 | +class WNUAOptions(TransportOptionsBase): |
| 130 | + """Options for WNUA transport mode.""" |
| 131 | + |
| 132 | + _MODE = TransportMode.WNUA |
| 133 | + |
| 134 | + port: int |
| 135 | + |
| 136 | + def _to_cyberchannel_kwargs(self) -> dict[str, Any]: |
| 137 | + return asdict(self) | {"transport_mode": self.mode.value, "host": "localhost"} |
| 138 | + |
| 139 | + |
| 140 | +@dataclass(kw_only=True) |
| 141 | +class MTLSOptions(TransportOptionsBase): |
| 142 | + """Options for mTLS transport mode.""" |
| 143 | + |
| 144 | + _MODE = TransportMode.MTLS |
| 145 | + |
| 146 | + certs_dir: str | Path | None = None |
| 147 | + host: str = "localhost" |
| 148 | + port: int |
| 149 | + allow_remote_host: bool = False |
| 150 | + |
| 151 | + def _to_cyberchannel_kwargs(self) -> dict[str, Any]: |
| 152 | + if not self.allow_remote_host: |
| 153 | + if self.host not in ("localhost", "127.0.0.1"): |
| 154 | + raise ValueError(f"Remote host '{self.host}' specified without setting 'allow_remote_host=True'.") |
| 155 | + res = asdict(self) |
| 156 | + res.pop("allow_remote_host", None) |
| 157 | + return res | {"transport_mode": self.mode.value} |
| 158 | + |
| 159 | + |
| 160 | +@dataclass(kw_only=True) |
| 161 | +class InsecureOptions(TransportOptionsBase): |
| 162 | + """Options for insecure transport mode.""" |
| 163 | + |
| 164 | + _MODE = TransportMode.INSECURE |
| 165 | + |
| 166 | + host: str = "localhost" |
| 167 | + port: int |
| 168 | + allow_remote_host: bool = False |
| 169 | + |
| 170 | + def _to_cyberchannel_kwargs(self) -> dict[str, Any]: |
| 171 | + if not self.allow_remote_host: |
| 172 | + if self.host not in ("localhost", "127.0.0.1"): |
| 173 | + raise ValueError(f"Remote host '{self.host}' specified without setting 'allow_remote_host=True'.") |
| 174 | + res = asdict(self) |
| 175 | + res.pop("allow_remote_host", None) |
| 176 | + return res | {"transport_mode": self.mode.value} |
| 177 | + |
| 178 | + |
| 179 | +TransportOptionsType = UDSOptions | WNUAOptions | MTLSOptions | InsecureOptions |
0 commit comments