diff --git a/opengemini_client/__init__.py b/opengemini_client/__init__.py index a9badfd..6d0144b 100644 --- a/opengemini_client/__init__.py +++ b/opengemini_client/__init__.py @@ -32,3 +32,13 @@ SeriesResult, ValuesResult ) + +from .measurement import ( + FieldType, + ShardType, + IndexType, + EngineType, + ComparisonOperator, + Measurement, + MeasurementCondition +) diff --git a/opengemini_client/client.py b/opengemini_client/client.py index 9af877e..d260068 100644 --- a/opengemini_client/client.py +++ b/opengemini_client/client.py @@ -19,6 +19,7 @@ from typing import List from opengemini_client.models import BatchPoints, QueryResult, Query, RpConfig, ValuesResult +from opengemini_client.measurement import Measurement, MeasurementCondition class Client(ABC): @@ -85,6 +86,18 @@ def show_retention_policies(self, dbname): def drop_retention_policy(self, dbname, retention_policy: str): pass + @abstractmethod + def create_measurement(self, measurement: Measurement): + pass + + @abstractmethod + def show_measurements(self, condition: MeasurementCondition) -> List[str]: + pass + + @abstractmethod + def drop_measurement(self, database: str, retention_policy: str, measurement: str): + pass + @abstractmethod def show_tag_keys(self, database, command: str) -> List[ValuesResult]: """ diff --git a/opengemini_client/client_impl.py b/opengemini_client/client_impl.py index e172f03..0fac657 100644 --- a/opengemini_client/client_impl.py +++ b/opengemini_client/client_impl.py @@ -25,6 +25,7 @@ from requests import HTTPError from opengemini_client.client import Client +from opengemini_client.measurement import Measurement, MeasurementCondition from opengemini_client.models import Config, BatchPoints, Query, QueryResult, Series, SeriesResult, RpConfig, \ ValuesResult, KeyValue from opengemini_client.url_const import UrlConst @@ -97,10 +98,10 @@ def __enter__(self): def __exit__(self, _exc_type, _exc_val, _exc_tb): self.session.close() - def get_server_url(self): + def _get_server_url(self): return next(self.endpoints_iter) - def update_headers(self, method, url_path, headers=None) -> dict: + def _update_headers(self, method, url_path, headers=None) -> dict: if headers is None: headers = {} @@ -121,10 +122,10 @@ def update_headers(self, method, url_path, headers=None) -> dict: return headers - def request(self, method, server_url, url_path, headers=None, body=None, params=None) -> requests.Response: + def _request(self, method, server_url, url_path, headers=None, body=None, params=None) -> requests.Response: if params is None: params = {} - headers = self.update_headers(method, url_path, headers) + headers = self._update_headers(method, url_path, headers) full_url = server_url + url_path if self.config.gzip_enabled and body is not None: compressed = io.BytesIO() @@ -139,36 +140,37 @@ def request(self, method, server_url, url_path, headers=None, body=None, params= raise HTTPError(f"request error resp, code: {resp.status_code}, body: {resp.text}") return resp - def exec_http_request_by_index(self, idx, method, url_path, headers=None, body=None) -> requests.Response: + def _exec_http_request_by_index(self, idx, method, url_path, headers=None, body=None) -> requests.Response: if idx >= len(self.endpoints) or idx < 0: raise ValueError("openGeminiDB client error. Index out of range") - return self.request(method, self.endpoints[idx], url_path, headers, body) + return self._request(method, self.endpoints[idx], url_path, headers, body) def ping(self, idx: int): - resp = self.exec_http_request_by_index(idx, 'GET', UrlConst.PING) + resp = self._exec_http_request_by_index(idx, 'GET', UrlConst.PING) if resp.status_code != HTTPStatus.NO_CONTENT: raise HTTPError(f"ping error resp, code: {resp.status_code}, body: {resp.text}") def query(self, query: Query) -> QueryResult: - server_url = self.get_server_url() - params = {'db': query.database, 'q': query.command, 'rp': query.retention_policy} + server_url = self._get_server_url() + params = {'db': query.database, 'q': query.command, 'rp': query.retention_policy, + 'epoch': query.precision.epoch()} - resp = self.request(method='GET', server_url=server_url, url_path=UrlConst.QUERY, params=params) + resp = self._request(method='GET', server_url=server_url, url_path=UrlConst.QUERY, params=params) if resp.status_code == HTTPStatus.OK: return resolve_query_body(resp) raise HTTPError(f"query error resp, code: {resp.status_code}, body: {resp.text}") def _query_post(self, query: Query) -> QueryResult: - server_url = self.get_server_url() + server_url = self._get_server_url() params = {'db': query.database, 'q': query.command, 'rp': query.retention_policy} - resp = self.request(method='POST', server_url=server_url, url_path=UrlConst.QUERY, params=params) + resp = self._request(method='POST', server_url=server_url, url_path=UrlConst.QUERY, params=params) if resp.status_code == HTTPStatus.OK: return resolve_query_body(resp) raise HTTPError(f"query_post error resp, code: {resp.status_code}, body: {resp.text}") def write_batch_points(self, database: str, batch_points: BatchPoints): - server_url = self.get_server_url() + server_url = self._get_server_url() params = {'db': database} with io.StringIO() as writer: for point in batch_points.points: @@ -177,7 +179,7 @@ def write_batch_points(self, database: str, batch_points: BatchPoints): writer.write(point.to_string()) writer.write('\n') body = writer.getvalue().encode() - resp = self.request(method="POST", server_url=server_url, url_path=UrlConst.WRITE, params=params, body=body) + resp = self._request(method="POST", server_url=server_url, url_path=UrlConst.WRITE, params=params, body=body) if resp.status_code == HTTPStatus.NO_CONTENT: return raise HTTPError(f"write_batch_points error resp, code: {resp.status_code}, body: {resp.text}") @@ -279,6 +281,39 @@ def _show_with_result_key_value(self, database, command: str) -> List[ValuesResu values_results.append(values_result) return values_results + def create_measurement(self, measurement: Measurement): + if measurement is None: + raise ValueError("empty measurement") + measurement.check() + command = measurement.to_string() + return self._query_post(Query(database=measurement.database, command=command, retention_policy='')) + + def show_measurements(self, condition: MeasurementCondition) -> List[str]: + if condition is None: + raise ValueError("empty measurement condition") + condition.check() + command = condition.to_string() + result = self.query(Query(database=condition.database, command=command, retention_policy='')) + if result.error is not None: + raise HTTPError(f"show_measurements error result, error: {result.error}") + measurements = [] + if len(result.results) == 0 or len(result.results[0].series) == 0: + return measurements + if result.results[0].error is not None: + raise HTTPError(f"show_measurements error result, error: {result.results[0].error}") + for v in result.results[0].series[0].values: + if isinstance(v[0], str): + measurements.append(str(v[0])) + return measurements + + def drop_measurement(self, database: str, retention_policy: str, measurement: str): + if not database: + raise ValueError("empty database name") + if not measurement: + raise ValueError("empty measurement name") + command = f"DROP MEASUREMENT {measurement}" + return self._query_post(Query(database=database, command=command, retention_policy=retention_policy)) + def show_tag_keys(self, database, command: str) -> List[ValuesResult]: return self._show_with_result_any(database, command) diff --git a/opengemini_client/measurement.py b/opengemini_client/measurement.py new file mode 100644 index 0000000..0c62bb7 --- /dev/null +++ b/opengemini_client/measurement.py @@ -0,0 +1,170 @@ +# Copyright 2025 openGemini Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Dict + +ErrEmptyDatabaseName = "empty database name" +ErrEmptyMeasurement = "empty measurement" +ErrEmptyTagOrField = "empty tag or field" +ErrEmptyIndexList = "empty index list" + + +class FieldType(Enum): + FieldTypeInt64 = "INT64" + FieldTypeFloat64 = "FLOAT64" + FieldTypeString = "STRING" + FieldTypeBool = "BOOL" + + +class ShardType(Enum): + ShardTypeHash = "HASH" + ShardTypeRange = "RANGE" + + +class IndexType(Enum): + IndexTypeText = "text" + + +class EngineType(Enum): + EngineTypeColumnstore = "columnstore" + + +class ComparisonOperator(Enum): + Equals = "=" + NotEquals = "<>" + GreaterThan = ">" + LessThan = "<" + GreaterThanOrEquals = ">=" + LessThanOrEquals = "<=" + Match = "=~" + NotMatch = "!~" + + +@dataclass +class Measurement: + database: str + measurement: str + retention_policy: str + # specify tag list to create measurement + tags: List[str] = field(default_factory=list) + # specify field map to create measurement + fields: Dict[str, FieldType] = field(default_factory=dict) + # specify shard type to create measurement, support ShardTypeHash and ShardTypeRange two ways to + # break up data, required when use high series cardinality storage engine(HSCE) + shard_type: ShardType = None + # specify shard keys(tag as partition key) to create measurement, required when use + # high series cardinality storage engine(HSCE) + shard_keys: List[str] = field(default_factory=list) + # FullTextIndex required when want measurement support full-text index + index_type: IndexType = None + # required when specify which Field fields to create a full-text index on, + # these fields must be 'string' data type + index_list: List[str] = field(default_factory=list) + # required when want measurement support HSCE, set EngineTypeColumnStore + engine_type: EngineType = None + # required when use HSCE, such as the primary key is `location` and `direction`, which means that the + # storage engine will create indexes on these two fields + primary_keys: List[str] = field(default_factory=list) + # required when use HSCE, specify the data sorting method inside the storage engine, time means sorting + # by time, and can also be changed to rtt or direction, or even other fields in the table + sort_keys: List[str] = field(default_factory=list) + + def check(self): + if len(self.database) == 0: + raise ValueError(ErrEmptyDatabaseName) + if len(self.measurement) == 0: + raise ValueError(ErrEmptyMeasurement) + if len(self.tags) == 0 and len(self.fields) == 0: + raise ValueError(ErrEmptyTagOrField) + if self.index_type is not None and len(self.index_list) == 0: + raise ValueError(ErrEmptyIndexList) + + def _write_tags_fields(self, writer: io.StringIO): + writer.write(f"CREATE MEASUREMENT {self.measurement} (") + if len(self.tags) != 0: + tags = [] + for tag in self.tags: + tags.append(f"{tag} TAG") + writer.write(",".join(tags)) + if len(self.tags) != 0 and len(self.fields) != 0: + writer.write(",") + if len(self.fields) != 0: + fields = [] + for key, value in self.fields.items(): + fields.append(f"{key} {value.value} FIELD") + writer.write(",".join(fields)) + writer.write(")") + + def _write_index(self, writer: io.StringIO): + writer.write(" WITH ") + writer.write(f" INDEXTYPE {self.index_type.value}") + writer.write(" INDEXLIST " + ",".join(self.index_list)) + + def _writer_other(self, writer: io.StringIO): + with_identifier = False + if self.engine_type is not None: + with_identifier = True + writer.write(" WITH ") + writer.write(f" ENGINETYPE = {self.engine_type.value}") + if len(self.shard_keys) != 0: + if with_identifier is False: + with_identifier = True + writer.write(" WITH ") + writer.write(" SHARDKEY " + ",".join(self.shard_keys)) + if self.shard_type is not None: + if with_identifier is False: + with_identifier = True + writer.write(" WITH ") + writer.write(f" TYPE {self.shard_type.value}") + if len(self.primary_keys) != 0: + if with_identifier is False: + with_identifier = True + writer.write(" WITH ") + writer.write(" PRIMARYKEY " + ",".join(self.primary_keys)) + if len(self.sort_keys) != 0: + if with_identifier is False: + writer.write(" WITH ") + writer.write(" SORTKEY " + ",".join(self.sort_keys)) + + def to_string(self) -> str: + writer = io.StringIO() + self._write_tags_fields(writer) + + if self.index_type is not None: + self._write_index(writer) + return writer.getvalue() + + self._writer_other(writer) + return writer.getvalue() + + +@dataclass +class MeasurementCondition: + database: str + Operator: ComparisonOperator = None + Value: str = '' + + def check(self): + if len(self.database) == 0: + raise ValueError(ErrEmptyDatabaseName) + + def to_string(self) -> str: + writer = io.StringIO() + writer.write("SHOW MEASUREMENTS") + if self.Operator is not None: + writer.write(f" WITH MEASUREMENT {self.Operator.value} {self.Value}") + return writer.getvalue() diff --git a/opengemini_client/measurement_test.py b/opengemini_client/measurement_test.py new file mode 100644 index 0000000..d515412 --- /dev/null +++ b/opengemini_client/measurement_test.py @@ -0,0 +1,156 @@ +# Copyright 2025 openGemini Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import time + +from opengemini_client import test_utils +from opengemini_client.measurement import Measurement, FieldType, IndexType, ErrEmptyTagOrField, ErrEmptyDatabaseName, \ + ErrEmptyMeasurement, ErrEmptyIndexList, MeasurementCondition, ComparisonOperator + + +class MeasurementTest(unittest.TestCase): + + def test_create_measurement_success(self): + with test_utils.get_test_default_client() as cli: + cli.create_database('m_test') + cli.create_measurement(measurement=Measurement( + database="m_test", + measurement="m1", + retention_policy="", + tags=["tag1", "tag2"], + fields={ + "field_str": FieldType.FieldTypeString, + "field_int64": FieldType.FieldTypeInt64, + }, + index_type=IndexType.IndexTypeText, + index_list=["field_str"] + )) + time.sleep(5) + field_results = cli.show_field_keys('m_test', 'show field keys from m1') + print(field_results) + self.assertEqual(2, len(field_results[0].values)) + cli.drop_database("m_test") + + def test_create_measurement_failed(self): + with test_utils.get_test_default_client() as cli: + # measurement is None + with self.assertRaises(ValueError) as context: + cli.create_measurement(measurement=None) + print(context.exception) + self.assertRegex(str(context.exception), "empty measurement") + + # database is None + with self.assertRaises(ValueError) as context: + cli.create_measurement(measurement=Measurement( + database="", + measurement="m1", + retention_policy="", + )) + print(context.exception) + self.assertRegex(str(context.exception), ErrEmptyDatabaseName) + + # measurement is None + with self.assertRaises(ValueError) as context: + cli.create_measurement(measurement=Measurement( + database="m_test", + measurement="", + retention_policy="", + )) + print(context.exception) + self.assertRegex(str(context.exception), ErrEmptyMeasurement) + + # tags fields is None + with self.assertRaises(ValueError) as context: + cli.create_measurement(measurement=Measurement( + database="m_test", + measurement="m1", + retention_policy="", + )) + print(context.exception) + self.assertRegex(str(context.exception), ErrEmptyTagOrField) + + with self.assertRaises(ValueError) as context: + cli.create_measurement(measurement=Measurement( + database="m_test", + measurement="m1", + retention_policy="", + tags=["tag1", "tag2"], + fields={ + "field_str": FieldType.FieldTypeString, + "field_int64": FieldType.FieldTypeInt64, + }, + index_type=IndexType.IndexTypeText, + )) + print(context.exception) + self.assertRegex(str(context.exception), ErrEmptyIndexList) + + def test_show_measurements_success(self): + with test_utils.get_test_default_client() as cli: + cli.create_database('m_show_test') + cli.create_measurement(measurement=Measurement( + database="m_show_test", + measurement="m1", + retention_policy="", + tags=["tag1"], + fields={ + "field_str": FieldType.FieldTypeString, + }, + index_type=IndexType.IndexTypeText, + index_list=["field_str"] + )) + cli.create_measurement(measurement=Measurement( + database="m_show_test", + measurement="m2", + retention_policy="", + tags=["tag2"], + fields={ + "field_str": FieldType.FieldTypeString, + }, + index_type=IndexType.IndexTypeText, + index_list=["field_str"] + )) + time.sleep(5) + ms = cli.show_measurements(condition=MeasurementCondition( + database="m_show_test", + Operator=ComparisonOperator.Match, + Value="/m/" + )) + print(ms) + self.assertEqual(2, len(ms)) + cli.drop_database("m_show_test") + + def test_drop_measurement_success(self): + with test_utils.get_test_default_client() as cli: + cli.create_database('m_test') + cli.create_measurement(measurement=Measurement( + database="m_test", + measurement="m1", + retention_policy="", + tags=["tag1"], + fields={ + "field_str": FieldType.FieldTypeString, + }, + index_type=IndexType.IndexTypeText, + index_list=["field_str"] + )) + cli.drop_measurement("m_test", "", "m1") + ms = cli.show_measurements(condition=MeasurementCondition( + database="m_test", + Operator=ComparisonOperator.Match, + Value="/m/" + )) + print(ms) + self.assertEqual(0, len(ms)) + cli.drop_database("m_test") diff --git a/opengemini_client/models.py b/opengemini_client/models.py index 3ba7fe1..db84cd1 100644 --- a/opengemini_client/models.py +++ b/opengemini_client/models.py @@ -85,6 +85,23 @@ class Precision(Enum): PrecisionMinute = 4 PrecisionHour = 5 + def epoch(self) -> str: + if self == Precision.PrecisionNanoSecond: + unit = 'ns' + elif self == Precision.PrecisionMicrosecond: + unit = 'u' + elif self == Precision.PrecisionMillisecond: + unit = 'ms' + elif self == Precision.PrecisionSecond: + unit = 's' + elif self == Precision.PrecisionMinute: + unit = 'm' + elif self == Precision.PrecisionHour: + unit = 'h' + else: + unit = '' + return unit + def round_datetime(dt: datetime, round_to: timedelta): if round_to.seconds == 0: @@ -214,6 +231,7 @@ class Query: database: str command: str retention_policy: str + precision: Precision = Precision.PrecisionNanoSecond @dataclass