|
1 | 1 | import importlib.util |
| 2 | +import os |
2 | 3 | import urllib.parse |
| 4 | +from typing import Any |
3 | 5 |
|
4 | 6 | import pyarrow as pa |
5 | 7 |
|
|
8 | 10 | from influxdb_client_3.write_client import InfluxDBClient as _InfluxDBClient, WriteOptions, Point |
9 | 11 | from influxdb_client_3.write_client.client.exceptions import InfluxDBError |
10 | 12 | from influxdb_client_3.write_client.client.write_api import WriteApi as _WriteApi, SYNCHRONOUS, ASYNCHRONOUS, \ |
11 | | - PointSettings |
| 13 | + PointSettings, DefaultWriteOptions, WriteType |
12 | 14 | from influxdb_client_3.write_client.domain.write_precision import WritePrecision |
13 | 15 |
|
14 | 16 | polars = importlib.util.find_spec("polars") is not None |
15 | 17 |
|
| 18 | +INFLUX_HOST = "INFLUX_HOST" |
| 19 | +INFLUX_TOKEN = "INFLUX_TOKEN" |
| 20 | +INFLUX_DATABASE = "INFLUX_DATABASE" |
| 21 | +INFLUX_ORG = "INFLUX_ORG" |
| 22 | +INFLUX_PRECISION = "INFLUX_PRECISION" |
| 23 | +INFLUX_AUTH_SCHEME = "INFLUX_AUTH_SCHEME" |
| 24 | + |
16 | 25 |
|
17 | 26 | def write_client_options(**kwargs): |
18 | 27 | """ |
@@ -84,6 +93,27 @@ def _merge_options(defaults, exclude_keys=None, custom=None): |
84 | 93 | return _deep_merge(defaults, {key: value for key, value in custom.items() if key not in exclude_keys}) |
85 | 94 |
|
86 | 95 |
|
| 96 | +def _parse_precision(precision): |
| 97 | + """ |
| 98 | + Parses the precision value and ensures it is valid. |
| 99 | +
|
| 100 | + This function checks that the given `precision` is one of the allowed |
| 101 | + values defined in `WritePrecision`. If the precision is invalid, it |
| 102 | + raises a `ValueError`. The function returns the valid precision value |
| 103 | + if it passes validation. |
| 104 | +
|
| 105 | + :param precision: The precision value to be validated. |
| 106 | + Must be one of WritePrecision.NS, WritePrecision.MS, |
| 107 | + WritePrecision.S, or WritePrecision.US. |
| 108 | + :return: The valid precision value. |
| 109 | + :rtype: WritePrecision |
| 110 | + :raises ValueError: If the provided precision is not valid. |
| 111 | + """ |
| 112 | + if precision not in [WritePrecision.NS, WritePrecision.MS, WritePrecision.S, WritePrecision.US]: |
| 113 | + raise ValueError(f"Invalid precision value: {precision}") |
| 114 | + return precision |
| 115 | + |
| 116 | + |
87 | 117 | class InfluxDBClient3: |
88 | 118 | def __init__( |
89 | 119 | self, |
@@ -137,8 +167,23 @@ def __init__( |
137 | 167 | self._org = org if org is not None else "default" |
138 | 168 | self._database = database |
139 | 169 | self._token = token |
140 | | - self._write_client_options = write_client_options if write_client_options is not None \ |
141 | | - else default_client_options(write_options=SYNCHRONOUS) |
| 170 | + |
| 171 | + write_type = DefaultWriteOptions.write_type.value |
| 172 | + write_precision = DefaultWriteOptions.write_precision.value |
| 173 | + if isinstance(write_client_options, dict) and write_client_options.get('write_options') is not None: |
| 174 | + write_opts = write_client_options['write_options'] |
| 175 | + write_type = getattr(write_opts, 'write_type', write_type) |
| 176 | + write_precision = getattr(write_opts, 'write_precision', write_precision) |
| 177 | + |
| 178 | + write_options = WriteOptions( |
| 179 | + write_type=write_type, |
| 180 | + write_precision=write_precision, |
| 181 | + ) |
| 182 | + |
| 183 | + self._write_client_options = { |
| 184 | + "write_options": write_options, |
| 185 | + **(write_client_options or {}) |
| 186 | + } |
142 | 187 |
|
143 | 188 | # Parse the host input |
144 | 189 | parsed_url = urllib.parse.urlparse(host) |
@@ -179,6 +224,39 @@ def __init__( |
179 | 224 | flight_client_options=flight_client_options, |
180 | 225 | proxy=kwargs.get("proxy", None), options=q_opts_builder.build()) |
181 | 226 |
|
| 227 | + @classmethod |
| 228 | + def from_env(cls, **kwargs: Any) -> 'InfluxDBClient3': |
| 229 | + |
| 230 | + required_vars = { |
| 231 | + INFLUX_HOST: os.getenv(INFLUX_HOST), |
| 232 | + INFLUX_TOKEN: os.getenv(INFLUX_TOKEN), |
| 233 | + INFLUX_DATABASE: os.getenv(INFLUX_DATABASE) |
| 234 | + } |
| 235 | + missing_vars = [var for var, value in required_vars.items() if value is None or value == ""] |
| 236 | + if missing_vars: |
| 237 | + raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}") |
| 238 | + |
| 239 | + write_options = WriteOptions(write_type=WriteType.synchronous) |
| 240 | + |
| 241 | + precision = os.getenv(INFLUX_PRECISION) |
| 242 | + if precision is not None: |
| 243 | + write_options.write_precision = _parse_precision(precision) |
| 244 | + |
| 245 | + write_client_option = {'write_options': write_options} |
| 246 | + |
| 247 | + if os.getenv(INFLUX_AUTH_SCHEME) is not None: |
| 248 | + kwargs['auth_scheme'] = os.getenv(INFLUX_AUTH_SCHEME) |
| 249 | + |
| 250 | + org = os.getenv(INFLUX_ORG, "default") |
| 251 | + return InfluxDBClient3( |
| 252 | + host=required_vars[INFLUX_HOST], |
| 253 | + token=required_vars[INFLUX_TOKEN], |
| 254 | + database=required_vars[INFLUX_DATABASE], |
| 255 | + write_client_options=write_client_option, |
| 256 | + org=org, |
| 257 | + **kwargs |
| 258 | + ) |
| 259 | + |
182 | 260 | def write(self, record=None, database=None, **kwargs): |
183 | 261 | """ |
184 | 262 | Write data to InfluxDB. |
|
0 commit comments