diff --git a/README.md b/README.md index 33c93c0..4ce0d53 100644 --- a/README.md +++ b/README.md @@ -115,3 +115,40 @@ if __name__ == "__main__": print(f"query failed, {error}") ``` + +Enable TLS (skip certificate authentication): + +```python +from opengemini_client import Client, Config, Address, TlsConfig + +if __name__ == "__main__": + config = Config(address=[Address(host='127.0.0.1', port=8443)], tls_enabled=True) + cli = Client(config) + try: + cli.ping(0) + print("ping success") + except Exception as error: + print(f"ping failed, {error}") + +``` + +Enable TLS (Certificate Authentication): + +```python +import ssl +from opengemini_client import Client, Config, Address, TlsConfig + +if __name__ == "__main__": + context = ssl.SSLContext() + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations("ca.crt") + config = Config(address=[Address(host='127.0.0.1', port=8443)], tls_enabled=True, + tls_config=TlsConfig(ca_file="ca.crt")) + cli = Client(config) + try: + cli.ping(0) + print("ping success") + except Exception as error: + print(f"ping failed, {error}") + +``` diff --git a/README_CN.md b/README_CN.md index f9bca27..480b1ac 100644 --- a/README_CN.md +++ b/README_CN.md @@ -116,3 +116,40 @@ if __name__ == "__main__": print(f"query failed, {error}") ``` + +开启TLS(跳过证书认证): + +```python +from opengemini_client import Client, Config, Address, TlsConfig + +if __name__ == "__main__": + config = Config(address=[Address(host='127.0.0.1', port=8443)], tls_enabled=True) + cli = Client(config) + try: + cli.ping(0) + print("ping success") + except Exception as error: + print(f"ping failed, {error}") + +``` + +开启TLS(证书认证): + +```python +import ssl +from opengemini_client import Client, Config, Address, TlsConfig + +if __name__ == "__main__": + context = ssl.SSLContext() + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations("ca.crt") + config = Config(address=[Address(host='127.0.0.1', port=8443)], tls_enabled=True, + tls_config=TlsConfig(ca_file="ca.crt")) + cli = Client(config) + try: + cli.ping(0) + print("ping success") + except Exception as error: + print(f"ping failed, {error}") + +``` diff --git a/opengemini_client/__init__.py b/opengemini_client/__init__.py index 6d0144b..222974f 100644 --- a/opengemini_client/__init__.py +++ b/opengemini_client/__init__.py @@ -30,6 +30,7 @@ RpConfig, Series, SeriesResult, + TlsConfig, ValuesResult ) diff --git a/opengemini_client/client_impl.py b/opengemini_client/client_impl.py index 0fac657..f8d4acc 100644 --- a/opengemini_client/client_impl.py +++ b/opengemini_client/client_impl.py @@ -29,7 +29,7 @@ from opengemini_client.models import Config, BatchPoints, Query, QueryResult, Series, SeriesResult, RpConfig, \ ValuesResult, KeyValue from opengemini_client.url_const import UrlConst -from opengemini_client.models import AuthType +from opengemini_client.models import AuthType, TlsConfig def check_config(config: Config): @@ -45,6 +45,9 @@ def check_config(config: Config): if config.auth_config.auth_type == AuthType.TOKEN and len(config.auth_config.token) == 0: raise ValueError("invalid auth config due to empty token") + if config.tls_enabled and config.tls_config is None: + config.tls_config = TlsConfig() + if config.batch_config is not None: if config.batch_config.batch_interval <= 0: raise ValueError("batch enabled,batch interval must be greater than 0") @@ -85,7 +88,11 @@ class OpenGeminiDBClient(Client, ABC): def __init__(self, config: Config): self.config = check_config(config) self.session = requests.Session() - protocol = "https://" if config.tls_enabled else "http://" + protocol = "http://" + if config.tls_enabled: + protocol = "https://" + self.session.cert = (config.tls_config.cert_file, config.tls_config.key_file) + self.session.verify = config.tls_config.ca_file self.endpoints = [f"{protocol}{addr.host}:{addr.port}" for addr in config.address] self.endpoints_iter = itertools.cycle(self.endpoints) diff --git a/opengemini_client/models.py b/opengemini_client/models.py index db84cd1..a380d70 100644 --- a/opengemini_client/models.py +++ b/opengemini_client/models.py @@ -13,7 +13,6 @@ # limitations under the License. import io -import ssl from dataclasses import field, dataclass from datetime import datetime, timedelta from enum import Enum @@ -39,6 +38,13 @@ class AuthConfig: token: str = '' +@dataclass +class TlsConfig: + cert_file: str = '' + key_file: str = '' + ca_file: str = '' + + @dataclass class BatchConfig: batch_interval: int @@ -54,7 +60,7 @@ class Config: gzip_enabled: bool = False tls_enabled: bool = False auth_config: AuthConfig = None - tls_config: ssl.SSLContext = None + tls_config: TlsConfig = None @dataclass